Refactor dataset features
parent 757ea175d3
commit aed9f4036a
@@ -30,11 +30,11 @@ from huggingface_hub import snapshot_download, upload_folder
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
+    DEFAULT_FEATURES,
     EPISODES_PATH,
     INFO_PATH,
     STATS_PATH,
     TASKS_PATH,
-    _get_info_from_robot,
     append_jsonlines,
     check_delta_timestamps,
     check_timestamps_sync,
@@ -43,6 +43,7 @@ from lerobot.common.datasets.utils import (
     create_empty_dataset_info,
     get_delta_indices,
     get_episode_data_index,
+    get_features_from_robot,
     get_hub_safe_version,
     hf_transform_to_torch,
     load_episodes,
@@ -116,7 +117,7 @@ class LeRobotDatasetMetadata:
 
     def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
         ep_chunk = self.get_episode_chunk(ep_index)
-        fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+        fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
         return Path(fpath)
 
     def get_episode_chunk(self, ep_index: int) -> int:
@@ -128,15 +129,20 @@ class LeRobotDatasetMetadata:
         return self.info["data_path"]
 
     @property
-    def videos_path(self) -> str | None:
+    def video_path(self) -> str | None:
         """Formattable string for the video files."""
-        return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
+        return self.info["video_path"]
 
     @property
     def fps(self) -> int:
         """Frames per second used during data collection."""
         return self.info["fps"]
 
+    @property
+    def features(self) -> dict[str, dict]:
+        """"""
+        return self.info["features"]
+
     @property
     def keys(self) -> list[str]:
         """Keys to access non-image data (state, actions etc.)."""
@@ -145,22 +151,27 @@ class LeRobotDatasetMetadata:
     @property
     def image_keys(self) -> list[str]:
         """Keys to access visual modalities stored as images."""
-        return self.info["image_keys"]
+        return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
 
     @property
     def video_keys(self) -> list[str]:
         """Keys to access visual modalities stored as videos."""
-        return self.info["video_keys"]
+        return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
 
     @property
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
-        return self.image_keys + self.video_keys
+        return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
 
     @property
-    def names(self) -> dict[list[str]]:
+    def names(self) -> dict[str, list[str]]:
         """Names of the various dimensions of vector modalities."""
-        return self.info["names"]
+        return {key: ft["names"] for key, ft in self.features.items()}
 
+    @property
+    def shapes(self) -> dict:
+        """Shapes for the different features."""
+        return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
+
     @property
     def total_episodes(self) -> int:
@@ -187,11 +198,6 @@ class LeRobotDatasetMetadata:
         """Max number of episodes per chunk."""
         return self.info["chunks_size"]
 
-    @property
-    def shapes(self) -> dict:
-        """Shapes for the different features."""
-        return self.info["shapes"]
-
     @property
     def task_to_task_index(self) -> dict:
         return {task: task_idx for task_idx, task in self.tasks.items()}
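A hedged usage sketch of the derived accessors above; assume `meta` is a LeRobotDatasetMetadata instance for an existing dataset (not shown in this commit). Keys, names and shapes are no longer read from dedicated info.json entries but derived on the fly from the features dict:

    # Hedged sketch; `meta` is assumed to be an already-loaded LeRobotDatasetMetadata.
    print(meta.features)     # full schema: {key: {"dtype": ..., "shape": ..., "names": ...}}
    print(meta.video_keys)   # keys whose dtype is "video"
    print(meta.camera_keys)  # image and video keys combined
    print(meta.shapes)       # {key: tuple(shape)}, derived from features
    print(meta.names)        # {key: names}, derived from features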
@@ -253,45 +259,33 @@ class LeRobotDatasetMetadata:
         root: Path | None = None,
         robot: Robot | None = None,
         robot_type: str | None = None,
-        keys: list[str] | None = None,
-        image_keys: list[str] | None = None,
-        video_keys: list[str] = None,
-        shapes: dict | None = None,
-        names: dict | None = None,
+        features: dict | None = None,
         use_videos: bool = True,
     ) -> "LeRobotDatasetMetadata":
         """Creates metadata for a LeRobotDataset."""
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
-        obj.image_writer = None
 
         if robot is not None:
-            robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
+            features = get_features_from_robot(robot)
+            robot_type = robot.robot_type
            if not all(cam.fps == fps for cam in robot.cameras.values()):
                logging.warning(
                    f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
-                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
+                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
                )
-        elif (
-            robot_type is None
-            or keys is None
-            or image_keys is None
-            or video_keys is None
-            or shapes is None
-            or names is None
-        ):
+        elif robot_type is None or features is None:
             raise ValueError(
-                "Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
+                "Dataset features must either come from a Robot or explicitly passed upon creation."
             )
-
-        if len(video_keys) > 0 and not use_videos:
-            raise ValueError()
+        else:
+            features = {**features, **DEFAULT_FEATURES}
 
         obj.tasks, obj.stats, obj.episodes = {}, {}, []
-        obj.info = create_empty_dataset_info(
-            CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
-        )
+        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
+        if len(obj.video_keys) > 0 and not use_videos:
+            raise ValueError()
         write_json(obj.info, obj.root / INFO_PATH)
         obj.local_files_only = True
         return obj
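For orientation, a hedged sketch of the refactored creation path when no Robot object is available: features can now be passed explicitly and are merged with DEFAULT_FEATURES by create(). The repo id, fps and feature entries below are illustrative assumptions, not taken from this commit.

    # Illustrative only: explicit features instead of a Robot instance.
    from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata

    features = {
        "observation.state": {"dtype": "float32", "shape": (6,), "names": [f"motor_{i}" for i in range(6)]},
        "action": {"dtype": "float32", "shape": (6,), "names": [f"motor_{i}" for i in range(6)]},
    }
    meta = LeRobotDatasetMetadata.create(
        repo_id="user/my_dataset",   # assumed name
        fps=30,
        robot_type="custom",
        features=features,           # merged with DEFAULT_FEATURES since no robot is given
        use_videos=False,
    )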
@@ -509,6 +503,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         hf_dataset = load_dataset("parquet", data_files=files, split="train")
 
         hf_dataset.set_transform(hf_transform_to_torch)
+        # return hf_dataset.with_format("torch") TODO
         return hf_dataset
 
     @property
@@ -662,8 +657,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             "task_index": None,
             "frame_index": [],
             "timestamp": [],
-            "next.done": [],
-            **{key: [] for key in self.meta.keys},
+            **{key: [] for key in self.meta.features},
             **{key: [] for key in self.meta.image_keys},
         }
 
@@ -845,7 +839,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @classmethod
     def create(
         cls,
-        metadata: LeRobotDatasetMetadata,
+        repo_id: str,
+        fps: int,
+        root: Path | None = None,
+        robot: Robot | None = None,
+        robot_type: str | None = None,
+        features: dict | None = None,
+        use_videos: bool = True,
         tolerance_s: float = 1e-4,
         image_writer_processes: int = 0,
         image_writer_threads: int = 0,
@@ -853,7 +853,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
-        obj.meta = metadata
+        obj.meta = LeRobotDatasetMetadata.create(
+            repo_id=repo_id,
+            fps=fps,
+            root=root,
+            robot=robot,
+            robot_type=robot_type,
+            features=features,
+            use_videos=use_videos,
+        )
         obj.repo_id = obj.meta.repo_id
         obj.root = obj.meta.root
         obj.local_files_only = obj.meta.local_files_only
@@ -48,6 +48,14 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
 
 """
 
+DEFAULT_FEATURES = {
+    "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
+    "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
+    "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
+    "index": {"dtype": "int64", "shape": (1,), "names": None},
+    "task_index": {"dtype": "int64", "shape": (1,), "names": None},
+}
+
 
 def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
     """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
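For reference, a hedged sketch of what a complete features dict can look like once robot-specific entries are merged with DEFAULT_FEATURES; the key names and shapes below are illustrative assumptions, not values from this commit.

    # Illustrative only: one state vector, one action vector, one camera stored as video.
    features = {
        "observation.state": {"dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]},
        "action": {"dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]},
        "observation.images.laptop": {"dtype": "video", "shape": (640, 480, 3), "names": ["width", "height", "channels"]},
        # DEFAULT_FEATURES entries are merged in unchanged:
        "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
        "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
        "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
        "index": {"dtype": "int64", "shape": (1,), "names": None},
        "task_index": {"dtype": "int64", "shape": (1,), "names": None},
    }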
@@ -214,39 +222,25 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
     return version
 
 
-def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
-    shapes = {key: len(names) for key, names in robot.names.items()}
-    camera_shapes = {}
-    for key, cam in robot.cameras.items():
-        video_key = f"observation.images.{key}"
-        camera_shapes[video_key] = {
-            "width": cam.width,
-            "height": cam.height,
-            "channels": cam.channels,
-        }
-    keys = list(robot.names)
-    image_keys = [] if use_videos else list(camera_shapes)
-    video_keys = list(camera_shapes) if use_videos else []
-    shapes = {**shapes, **camera_shapes}
-    names = robot.names
-    robot_type = robot.robot_type
-
-    return robot_type, keys, image_keys, video_keys, shapes, names
+def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
+    camera_ft = {}
+    if robot.cameras:
+        camera_ft = {
+            key: {"dtype": "video" if use_videos else "image", **ft}
+            for key, ft in robot.camera_features.items()
+        }
+    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
 
 
 def create_empty_dataset_info(
     codebase_version: str,
     fps: int,
     robot_type: str,
-    keys: list[str],
-    image_keys: list[str],
-    video_keys: list[str],
-    shapes: dict,
-    names: dict,
+    features: dict,
+    use_videos: bool,
 ) -> dict:
     return {
         "codebase_version": codebase_version,
-        "data_path": DEFAULT_PARQUET_PATH,
         "robot_type": robot_type,
         "total_episodes": 0,
         "total_frames": 0,
@@ -256,12 +250,9 @@ def create_empty_dataset_info(
         "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": fps,
         "splits": {},
-        "keys": keys,
-        "video_keys": video_keys,
-        "image_keys": image_keys,
-        "shapes": shapes,
-        "names": names,
-        "videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
+        "data_path": DEFAULT_PARQUET_PATH,
+        "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
+        "features": features,
     }
 
 
@@ -400,6 +391,12 @@ def create_lerobot_dataset_card(
     tags: list | None = None, text: str | None = None, info: dict | None = None
 ) -> DatasetCard:
     card = DatasetCard(DATASET_CARD_TEMPLATE)
+    card.data.configs = [
+        {
+            "config_name": "default",
+            "data_files": "data/*/*.parquet",
+        }
+    ]
     card.data.task_categories = ["robotics"]
     card.data.tags = ["LeRobot"]
     if tags is not None:
@@ -106,6 +106,7 @@ import json
 import math
 import shutil
 import subprocess
+import tempfile
 import warnings
 from pathlib import Path
 
@@ -137,9 +138,8 @@ from lerobot.common.datasets.utils import (
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,  # noqa: F401
-    get_image_shapes,
+    get_image_pixel_channels,
     get_video_info,
-    get_video_shapes,
 )
 from lerobot.common.utils.utils import init_hydra_config
 
@@ -202,21 +202,37 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
         torch.testing.assert_close(stats_json[key], stats[key])
 
 
-def get_keys(dataset: Dataset) -> dict[str, list]:
-    sequence_keys, image_keys, video_keys = [], [], []
+def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
+    features = {}
     for key, ft in dataset.features.items():
+        if isinstance(ft, datasets.Value):
+            dtype = ft.dtype
+            shape = (1,)
+            names = None
         if isinstance(ft, datasets.Sequence):
-            sequence_keys.append(key)
+            assert isinstance(ft.feature, datasets.Value)
+            dtype = ft.feature.dtype
+            shape = (ft.length,)
+            names = robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
+            assert len(names) == shape[0]
         elif isinstance(ft, datasets.Image):
-            image_keys.append(key)
+            dtype = "image"
+            image = dataset[0][key]  # Assuming first row
+            channels = get_image_pixel_channels(image)
+            shape = (image.width, image.height, channels)
+            names = ["width", "height", "channel"]
         elif ft._type == "VideoFrame":
-            video_keys.append(key)
+            dtype = "video"
+            shape = None  # Add shape later
+            names = ["width", "height", "channel"]
 
-    return {
-        "sequence": sequence_keys,
-        "image": image_keys,
-        "video": video_keys,
-    }
+        features[key] = {
+            "dtype": dtype,
+            "shape": shape,
+            "names": names,
+        }
+
+    return features
 
 
 def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
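To make the inferred schema concrete, a hedged illustration of what get_features_from_hf_dataset could return for a v1.x dataset with a 6-dimensional state column and one video camera; the column names and sizes are assumptions, not taken from this commit.

    # Illustrative only: inferred v2.0 features for an assumed v1.x dataset.
    {
        "observation.state": {"dtype": "float32", "shape": (6,), "names": ["motor_0", "motor_1", "motor_2", "motor_3", "motor_4", "motor_5"]},
        "observation.images.top": {"dtype": "video", "shape": None, "names": ["width", "height", "channel"]},  # shape is filled in later from the video info
        "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
        "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
        "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
        "index": {"dtype": "int64", "shape": (1,), "names": None},
    }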
@@ -259,17 +275,15 @@ def add_task_index_from_tasks_col(
 
 def split_parquet_by_episodes(
     dataset: Dataset,
-    keys: dict[str, list],
     total_episodes: int,
     total_chunks: int,
     output_dir: Path,
 ) -> list:
-    table = dataset.remove_columns(keys["video"])._data.table
+    table = dataset.data.table
     episode_lengths = []
     for ep_chunk in range(total_chunks):
         ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
         ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
 
         chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
         (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
         for ep_idx in range(ep_chunk_start, ep_chunk_end):
@@ -396,27 +410,22 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
 
 
 def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
-    hub_api = HfApi()
-    videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
-
     # Assumes first episode
     video_files = [
         DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
         for vid_key in video_keys
     ]
+    hub_api = HfApi()
     hub_api.snapshot_download(
         repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
     )
+
+    videos_info_dict = {}
     for vid_key, vid_path in zip(video_keys, video_files, strict=True):
         videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
 
     return videos_info_dict
 
 
-def get_generic_motor_names(sequence_shapes: dict) -> dict:
-    return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}
-
-
 def convert_dataset(
     repo_id: str,
     local_dir: Path,
@@ -443,7 +452,8 @@ def convert_dataset(
 
     metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
     dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
-    keys = get_keys(dataset)
+    features = get_features_from_hf_dataset(dataset, robot_config)
+    video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
 
     if single_task and "language_instruction" in dataset.column_names:
         warnings.warn(
@@ -457,7 +467,7 @@ def convert_dataset(
     episode_indices = sorted(dataset.unique("episode_index"))
     total_episodes = len(episode_indices)
     assert episode_indices == list(range(total_episodes))
-    total_videos = total_episodes * len(keys["video"])
+    total_videos = total_episodes * len(video_keys)
     total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
     if total_episodes % DEFAULT_CHUNK_SIZE != 0:
         total_chunks += 1
@@ -470,7 +480,6 @@ def convert_dataset(
     elif tasks_path:
         tasks_by_episodes = load_json(tasks_path)
         tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
-        # tasks = list(set(tasks_by_episodes.values()))
         dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
         tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
     elif tasks_col:
@@ -481,56 +490,50 @@ def convert_dataset(
     assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
     tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
     write_jsonlines(tasks, v20_dir / TASKS_PATH)
-
-    # Shapes
-    sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
-    image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
+    features["task_index"] = {
+        "dtype": "int64",
+        "shape": (1,),
+        "names": None,
+    }
 
     # Videos
-    if len(keys["video"]) > 0:
+    if video_keys:
         assert metadata_v1.get("video", False)
-        tmp_video_dir = local_dir / "videos" / V20 / repo_id
-        tmp_video_dir.mkdir(parents=True, exist_ok=True)
+        dataset = dataset.remove_columns(video_keys)
         clean_gitattr = Path(
             hub_api.hf_hub_download(
                 repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
             )
         ).absolute()
-        move_videos(
-            repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, clean_gitattr, branch
-        )
-        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
-        video_shapes = get_video_shapes(videos_info, keys["video"])
-        for img_key in keys["video"]:
-            assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
+        with tempfile.TemporaryDirectory() as tmp_video_dir:
+            move_videos(
+                repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
+            )
+        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
+        for key in video_keys:
+            features[key]["shape"] = (
+                videos_info[key].pop("video.width"),
+                videos_info[key].pop("video.height"),
+                videos_info[key].pop("video.channels"),
+            )
+            features[key]["video_info"] = videos_info[key]
+            assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
             if "encoding" in metadata_v1:
-                assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
+                assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
     else:
         assert metadata_v1.get("video", 0) == 0
         videos_info = None
-        video_shapes = {}
 
     # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, total_chunks, v20_dir)
+    episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
 
-    # Names
     if robot_config is not None:
         robot_type = robot_config["robot_type"]
-        names = robot_config["names"]
-        if "observation.effort" in keys["sequence"]:
-            names["observation.effort"] = names["observation.state"]
-        if "observation.velocity" in keys["sequence"]:
-            names["observation.velocity"] = names["observation.state"]
         repo_tags = [robot_type]
     else:
         robot_type = "unknown"
-        names = get_generic_motor_names(sequence_shapes)
         repo_tags = None
 
-    assert set(names) == set(keys["sequence"])
-    for key in sequence_shapes:
-        assert len(names[key]) == sequence_shapes[key]
-
     # Episodes
     episodes = [
         {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
@@ -541,7 +544,6 @@ def convert_dataset(
     # Assemble metadata v2.0
     metadata_v2_0 = {
         "codebase_version": V20,
-        "data_path": DEFAULT_PARQUET_PATH,
         "robot_type": robot_type,
         "total_episodes": total_episodes,
         "total_frames": len(dataset),
@@ -551,15 +553,13 @@ def convert_dataset(
         "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": metadata_v1["fps"],
         "splits": {"train": f"0:{total_episodes}"},
-        "keys": keys["sequence"],
-        "video_keys": keys["video"],
-        "image_keys": keys["image"],
-        "shapes": {**sequence_shapes, **video_shapes, **image_shapes},
-        "names": names,
-        "videos": videos_info,
+        "data_path": DEFAULT_PARQUET_PATH,
+        "video_path": DEFAULT_VIDEO_PATH if video_keys else None,
+        "features": features,
     }
     write_json(metadata_v2_0, v20_dir / INFO_PATH)
     convert_stats_to_json(v1x_dir, v20_dir)
+    card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)
 
     with contextlib.suppress(EntryNotFoundError):
         hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
@@ -585,28 +585,11 @@ def convert_dataset(
         revision=branch,
     )
 
-    card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)
     card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
 
     if not test_branch:
         create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
 
-    # TODO:
-    # - [X] Add shapes
-    # - [X] Add keys
-    # - [X] Add paths
-    # - [X] convert stats.json
-    # - [X] Add task.json
-    # - [X] Add names
-    # - [X] Add robot_type
-    # - [X] Add splits
-    # - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
-    # - [X] Handle multitask datasets
-    # - [X] Handle hf hub repo limits (add chunks logic)
-    # - [X] Add test-branch
-    # - [X] Use jsonlines for episodes
-    # - [X] Add sanity checks (encoding, shapes)
-
 
 def main():
     parser = argparse.ArgumentParser()
@@ -25,7 +25,6 @@ from typing import Any, ClassVar
 import pyarrow as pa
 import torch
 import torchvision
-from datasets import Dataset
 from datasets.features.features import register_feature
 from PIL import Image
 
@@ -292,33 +291,6 @@ def get_video_info(video_path: Path | str) -> dict:
     return video_info
 
 
-def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
-    video_shapes = {}
-    for img_key in video_keys:
-        channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
-        video_shapes[img_key] = {
-            "width": videos_info[img_key]["video.width"],
-            "height": videos_info[img_key]["video.height"],
-            "channels": channels,
-        }
-
-    return video_shapes
-
-
-def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
-    image_shapes = {}
-    for img_key in image_keys:
-        image = dataset[0][img_key]  # Assuming first row
-        channels = get_image_pixel_channels(image)
-        image_shapes[img_key] = {
-            "width": image.width,
-            "height": image.height,
-            "channels": channels,
-        }
-
-    return image_shapes
-
-
 def get_video_pixel_channels(pix_fmt: str) -> int:
     if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
         return 1
@@ -226,13 +226,42 @@ class ManipulatorRobot:
         self.is_connected = False
         self.logs = {}
 
-        action_names = [f"{arm}_{motor}" for arm, bus in self.leader_arms.items() for motor in bus.motors]
-        state_names = [f"{arm}_{motor}" for arm, bus in self.follower_arms.items() for motor in bus.motors]
-        self.names = {
-            "action": action_names,
-            "observation.state": state_names,
+    def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
+        return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
+
+    @property
+    def camera_features(self) -> dict:
+        cam_ft = {}
+        for cam_key, cam in self.cameras.items():
+            key = f"observation.images.{cam_key}"
+            cam_ft[key] = {
+                "shape": (cam.width, cam.height, cam.channels),
+                "names": ["width", "height", "channels"],
+                "info": None,
+            }
+        return cam_ft
+
+    @property
+    def motor_features(self) -> dict:
+        action_names = self.get_motor_names(self.leader_arms)
+        state_names = self.get_motor_names(self.leader_arms)
+        return {
+            "action": {
+                "dtype": "float32",
+                "shape": (len(action_names),),
+                "names": action_names,
+            },
+            "observation.state": {
+                "dtype": "float32",
+                "shape": (len(state_names),),
+                "names": state_names,
+            },
         }
 
+    @property
+    def features(self):
+        return {**self.motor_features, **self.camera_features}
+
     @property
     def has_camera(self):
         return len(self.cameras) > 0
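As a rough usage sketch (not part of this commit), the robot-level feature properties above are what get_features_from_robot consumes when assembling the dataset schema; `robot` below is assumed to be an already-configured ManipulatorRobot.

    # Hedged sketch: building the dataset schema from a robot's feature properties.
    from lerobot.common.datasets.utils import get_features_from_robot

    features = get_features_from_robot(robot, use_videos=True)
    # Motor features keep the dtype/shape/names defined by the robot, camera features
    # gain dtype "video" (or "image" when use_videos=False), and DEFAULT_FEATURES
    # (timestamp, frame_index, episode_index, index, task_index) are merged in.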
@@ -11,6 +11,7 @@ def get_arm_id(name, arm_type):
 class Robot(Protocol):
     # TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
     robot_type: str
+    features: dict
 
     def connect(self): ...
     def run_calibration(self): ...
@@ -105,7 +105,7 @@ from pathlib import Path
 from typing import List
 
 # from safetensors.torch import load_file, save_file
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.robot_devices.control_utils import (
     control_loop,
     has_method,
@@ -234,15 +234,12 @@ def record(
 
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
-    dataset_metadata = LeRobotDatasetMetadata.create(
+    dataset = LeRobotDataset.create(
         repo_id,
         fps,
         root=root,
         robot=robot,
         use_videos=video,
-    )
-    dataset = LeRobotDataset.create(
-        dataset_metadata,
         image_writer_processes=num_image_writer_processes,
         image_writer_threads=num_image_writer_threads_per_camera,
     )