From aed9f4036a295156754a10d55411618efb1087ce Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Tue, 5 Nov 2024 13:10:43 +0100
Subject: [PATCH] Refactor dataset features

---
 lerobot/common/datasets/lerobot_dataset.py    |  90 ++++++------
 lerobot/common/datasets/utils.py              |  55 ++++---
 .../datasets/v2/convert_dataset_v1_to_v2.py   | 137 ++++++++----------
 lerobot/common/datasets/video_utils.py        |  28 ----
 .../robot_devices/robots/manipulator.py       |  39 ++++-
 lerobot/common/robot_devices/robots/utils.py  |   1 +
 lerobot/scripts/control_robot.py              |   7 +-
 7 files changed, 172 insertions(+), 185 deletions(-)

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index f5932b7e..f03d6826 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -30,11 +30,11 @@ from huggingface_hub import snapshot_download, upload_folder
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
+    DEFAULT_FEATURES,
     EPISODES_PATH,
     INFO_PATH,
     STATS_PATH,
     TASKS_PATH,
-    _get_info_from_robot,
     append_jsonlines,
     check_delta_timestamps,
     check_timestamps_sync,
@@ -43,6 +43,7 @@ from lerobot.common.datasets.utils import (
     create_empty_dataset_info,
     get_delta_indices,
     get_episode_data_index,
+    get_features_from_robot,
     get_hub_safe_version,
     hf_transform_to_torch,
     load_episodes,
@@ -116,7 +117,7 @@ class LeRobotDatasetMetadata:
 
     def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
         ep_chunk = self.get_episode_chunk(ep_index)
-        fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+        fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
         return Path(fpath)
 
     def get_episode_chunk(self, ep_index: int) -> int:
@@ -128,15 +129,20 @@ class LeRobotDatasetMetadata:
         return self.info["data_path"]
 
     @property
-    def videos_path(self) -> str | None:
+    def video_path(self) -> str | None:
         """Formattable string for the video files."""
-        return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
+        return self.info["video_path"]
 
     @property
     def fps(self) -> int:
         """Frames per second used during data collection."""
         return self.info["fps"]
 
+    @property
+    def features(self) -> dict[str, dict]:
+        """All features contained in the dataset."""
+        return self.info["features"]
+
     @property
     def keys(self) -> list[str]:
         """Keys to access non-image data (state, actions etc.)."""
@@ -145,22 +151,27 @@ class LeRobotDatasetMetadata:
     @property
     def image_keys(self) -> list[str]:
         """Keys to access visual modalities stored as images."""
-        return self.info["image_keys"]
+        return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
 
     @property
     def video_keys(self) -> list[str]:
         """Keys to access visual modalities stored as videos."""
-        return self.info["video_keys"]
+        return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
 
     @property
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
-        return self.image_keys + self.video_keys
+        return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
 
     @property
-    def names(self) -> dict[list[str]]:
+    def names(self) -> dict[str, list[str]]:
         """Names of the various dimensions of vector modalities."""
-        return self.info["names"]
+        return {key: ft["names"] for key, ft in self.features.items()}
+
+    @property
+    def shapes(self) -> dict:
+        """Shapes for the different features."""
+        return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
 
     @property
     def total_episodes(self) -> int:
@@ -187,11 +198,6 @@ class LeRobotDatasetMetadata:
         """Max number of episodes per chunk."""
         return self.info["chunks_size"]
 
-    @property
-    def shapes(self) -> dict:
-        """Shapes for the different features."""
-        return self.info["shapes"]
-
     @property
     def task_to_task_index(self) -> dict:
         return {task: task_idx for task_idx, task in self.tasks.items()}
@@ -253,45 +259,33 @@ class LeRobotDatasetMetadata:
         root: Path | None = None,
         robot: Robot | None = None,
         robot_type: str | None = None,
-        keys: list[str] | None = None,
-        image_keys: list[str] | None = None,
-        video_keys: list[str] = None,
-        shapes: dict | None = None,
-        names: dict | None = None,
+        features: dict | None = None,
         use_videos: bool = True,
     ) -> "LeRobotDatasetMetadata":
        """Creates metadata for a LeRobotDataset."""
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
-        obj.image_writer = None
 
         if robot is not None:
-            robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
+            features = get_features_from_robot(robot, use_videos)
+            robot_type = robot.robot_type
             if not all(cam.fps == fps for cam in robot.cameras.values()):
                 logging.warning(
                     f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
-                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
+                    " In this case, frames from lower fps cameras will be repeated to fill in the blanks."
                 )
-        elif (
-            robot_type is None
-            or keys is None
-            or image_keys is None
-            or video_keys is None
-            or shapes is None
-            or names is None
-        ):
+        elif robot_type is None or features is None:
             raise ValueError(
-                "Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
+                "Dataset features must either come from a Robot or be explicitly passed upon creation."
             )
-
-        if len(video_keys) > 0 and not use_videos:
-            raise ValueError()
+        else:
+            features = {**features, **DEFAULT_FEATURES}
 
         obj.tasks, obj.stats, obj.episodes = {}, {}, []
-        obj.info = create_empty_dataset_info(
-            CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
-        )
+        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
+        if len(obj.video_keys) > 0 and not use_videos:
+            raise ValueError("The 'features' dict contains video keys, but 'use_videos' is set to False.")
         write_json(obj.info, obj.root / INFO_PATH)
         obj.local_files_only = True
         return obj
@@ -509,6 +503,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             hf_dataset = load_dataset("parquet", data_files=files, split="train")
 
         hf_dataset.set_transform(hf_transform_to_torch)
+        # return hf_dataset.with_format("torch") TODO
         return hf_dataset
 
     @property
@@ -662,8 +657,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             "task_index": None,
             "frame_index": [],
             "timestamp": [],
-            "next.done": [],
-            **{key: [] for key in self.meta.keys},
+            **{key: [] for key in self.meta.features},
             **{key: [] for key in self.meta.image_keys},
         }
 
@@ -845,7 +839,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @classmethod
     def create(
         cls,
-        metadata: LeRobotDatasetMetadata,
+        repo_id: str,
+        fps: int,
+        root: Path | None = None,
+        robot: Robot | None = None,
+        robot_type: str | None = None,
+        features: dict | None = None,
+        use_videos: bool = True,
         tolerance_s: float = 1e-4,
         image_writer_processes: int = 0,
         image_writer_threads: int = 0,
@@ -853,7 +853,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
-        obj.meta = metadata
+        obj.meta = LeRobotDatasetMetadata.create(
+            repo_id=repo_id,
+            fps=fps,
+            root=root,
+            robot=robot,
+            robot_type=robot_type,
+            features=features,
+            use_videos=use_videos,
+        )
         obj.repo_id = obj.meta.repo_id
         obj.root = obj.meta.root
         obj.local_files_only = obj.meta.local_files_only
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index daebb505..eef319d9 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -48,6 +48,14 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
 
 """
 
+DEFAULT_FEATURES = {
+    "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
+    "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
+    "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
+    "index": {"dtype": "int64", "shape": (1,), "names": None},
+    "task_index": {"dtype": "int64", "shape": (1,), "names": None},
+}
+
 
 def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
     """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
@@ -214,39 +222,25 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
     return version
 
 
-def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
-    shapes = {key: len(names) for key, names in robot.names.items()}
-    camera_shapes = {}
-    for key, cam in robot.cameras.items():
-        video_key = f"observation.images.{key}"
-        camera_shapes[video_key] = {
-            "width": cam.width,
-            "height": cam.height,
-            "channels": cam.channels,
+def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
+    camera_ft = {}
+    if robot.cameras:
+        camera_ft = {
+            key: {"dtype": "video" if use_videos else "image", **ft}
+            for key, ft in robot.camera_features.items()
         }
-    keys = list(robot.names)
-    image_keys = [] if use_videos else list(camera_shapes)
-    video_keys = list(camera_shapes) if use_videos else []
-    shapes = {**shapes, **camera_shapes}
-    names = robot.names
-    robot_type = robot.robot_type
-
-    return robot_type, keys, image_keys, video_keys, shapes, names
+    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
 
 
 def create_empty_dataset_info(
     codebase_version: str,
     fps: int,
     robot_type: str,
-    keys: list[str],
-    image_keys: list[str],
-    video_keys: list[str],
-    shapes: dict,
-    names: dict,
+    features: dict,
+    use_videos: bool,
 ) -> dict:
     return {
         "codebase_version": codebase_version,
-        "data_path": DEFAULT_PARQUET_PATH,
         "robot_type": robot_type,
         "total_episodes": 0,
         "total_frames": 0,
@@ -256,12 +250,9 @@ def create_empty_dataset_info(
         "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": fps,
         "splits": {},
-        "keys": keys,
-        "video_keys": video_keys,
-        "image_keys": image_keys,
-        "shapes": shapes,
-        "names": names,
-        "videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
+        "data_path": DEFAULT_PARQUET_PATH,
+        "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
+        "features": features,
     }
 
 
@@ -400,6 +391,12 @@ def create_lerobot_dataset_card(
     tags: list | None = None, text: str | None = None, info: dict | None = None
 ) -> DatasetCard:
     card = DatasetCard(DATASET_CARD_TEMPLATE)
+    card.data.configs = [
+        {
+            "config_name": "default",
+            "data_files": "data/*/*.parquet",
+        }
+    ]
     card.data.task_categories = ["robotics"]
     card.data.tags = ["LeRobot"]
     if tags is not None:
diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
index 10312272..8432d609 100644
--- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
+++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
@@ -106,6 +106,7 @@
 import json
 import math
 import shutil
 import subprocess
+import tempfile
 import warnings
 from pathlib import Path
 
@@ -137,9 +138,8 @@ from lerobot.common.datasets.utils import (
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,  # noqa: F401
-    get_image_shapes,
+    get_image_pixel_channels,
     get_video_info,
-    get_video_shapes,
 )
 from lerobot.common.utils.utils import init_hydra_config
 
@@ -202,21 +202,37 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
         torch.testing.assert_close(stats_json[key], stats[key])
 
 
-def get_keys(dataset: Dataset) -> dict[str, list]:
-    sequence_keys, image_keys, video_keys = [], [], []
+def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, dict]:
+    features = {}
     for key, ft in dataset.features.items():
+        if isinstance(ft, datasets.Value):
+            dtype = ft.dtype
+            shape = (1,)
+            names = None
         if isinstance(ft, datasets.Sequence):
-            sequence_keys.append(key)
+            assert isinstance(ft.feature, datasets.Value)
+            dtype = ft.feature.dtype
+            shape = (ft.length,)
+            names = robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
+            assert len(names) == shape[0]
         elif isinstance(ft, datasets.Image):
-            image_keys.append(key)
+            dtype = "image"
+            image = dataset[0][key]  # Assuming first row
+            channels = get_image_pixel_channels(image)
+            shape = (image.width, image.height, channels)
+            names = ["width", "height", "channel"]
         elif ft._type == "VideoFrame":
-            video_keys.append(key)
+            dtype = "video"
+            shape = None  # Add shape later
+            names = ["width", "height", "channel"]
 
-    return {
-        "sequence": sequence_keys,
-        "image": image_keys,
-        "video": video_keys,
-    }
+        features[key] = {
+            "dtype": dtype,
+            "shape": shape,
+            "names": names,
+        }
+
+    return features
 
 
 def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
@@ -259,17 +275,15 @@ def add_task_index_from_tasks_col(
 
 def split_parquet_by_episodes(
     dataset: Dataset,
-    keys: dict[str, list],
     total_episodes: int,
     total_chunks: int,
     output_dir: Path,
 ) -> list:
-    table = dataset.remove_columns(keys["video"])._data.table
+    table = dataset.data.table
     episode_lengths = []
     for ep_chunk in range(total_chunks):
         ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
         ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
-
         chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
         (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
         for ep_idx in range(ep_chunk_start, ep_chunk_end):
@@ -396,27 +410,22 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
 
 
 def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
-    hub_api = HfApi()
-    videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
-
-    # Assumes first episode
     video_files = [
         DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
         for vid_key in video_keys
     ]
+    hub_api = HfApi()
     hub_api.snapshot_download(
         repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
     )
+    videos_info_dict = {}
     for vid_key, vid_path in zip(video_keys, video_files, strict=True):
         videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
 
     return videos_info_dict
 
 
-def get_generic_motor_names(sequence_shapes: dict) -> dict:
-    return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}
-
-
 def convert_dataset(
     repo_id: str,
     local_dir: Path,
@@ -443,7 +452,8 @@
     metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
     dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
-    keys = get_keys(dataset)
+    features = get_features_from_hf_dataset(dataset, robot_config)
+    video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
 
     if single_task and "language_instruction" in dataset.column_names:
         warnings.warn(
@@ -457,7 +467,7 @@
     episode_indices = sorted(dataset.unique("episode_index"))
     total_episodes = len(episode_indices)
     assert episode_indices == list(range(total_episodes))
-    total_videos = total_episodes * len(keys["video"])
+    total_videos = total_episodes * len(video_keys)
     total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
     if total_episodes % DEFAULT_CHUNK_SIZE != 0:
         total_chunks += 1
@@ -470,7 +480,6 @@
     elif tasks_path:
         tasks_by_episodes = load_json(tasks_path)
         tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
-        # tasks = list(set(tasks_by_episodes.values()))
         dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
         tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
     elif tasks_col:
@@ -481,56 +490,50 @@
     assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
     tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
     write_jsonlines(tasks, v20_dir / TASKS_PATH)
-
-    # Shapes
-    sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
-    image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
+    features["task_index"] = {
+        "dtype": "int64",
+        "shape": (1,),
+        "names": None,
+    }
 
     # Videos
-    if len(keys["video"]) > 0:
+    if video_keys:
         assert metadata_v1.get("video", False)
-        tmp_video_dir = local_dir / "videos" / V20 / repo_id
-        tmp_video_dir.mkdir(parents=True, exist_ok=True)
+        dataset = dataset.remove_columns(video_keys)
         clean_gitattr = Path(
             hub_api.hf_hub_download(
                 repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
             )
         ).absolute()
-        move_videos(
-            repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, clean_gitattr, branch
-        )
-        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
-        video_shapes = get_video_shapes(videos_info, keys["video"])
-        for img_key in keys["video"]:
-            assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
+        with tempfile.TemporaryDirectory() as tmp_video_dir:
+            move_videos(
+                repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
+            )
+        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
+        for key in video_keys:
+            features[key]["shape"] = (
+                videos_info[key].pop("video.width"),
+                videos_info[key].pop("video.height"),
+                videos_info[key].pop("video.channels"),
+            )
+            features[key]["video_info"] = videos_info[key]
+            assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
             if "encoding" in metadata_v1:
-                assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
+                assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
     else:
         assert metadata_v1.get("video", 0) == 0
         videos_info = None
-        video_shapes = {}
 
     # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, total_chunks, v20_dir)
+    episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
 
-    # Names
     if robot_config is not None:
         robot_type = robot_config["robot_type"]
-        names = robot_config["names"]
-        if "observation.effort" in keys["sequence"]:
-            names["observation.effort"] = names["observation.state"]
-        if "observation.velocity" in keys["sequence"]:
-            names["observation.velocity"] = names["observation.state"]
         repo_tags = [robot_type]
     else:
         robot_type = "unknown"
-        names = get_generic_motor_names(sequence_shapes)
         repo_tags = None
 
-    assert set(names) == set(keys["sequence"])
-    for key in sequence_shapes:
-        assert len(names[key]) == sequence_shapes[key]
-
     # Episodes
     episodes = [
         {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
@@ -541,7 +544,6 @@
     # Assemble metadata v2.0
     metadata_v2_0 = {
         "codebase_version": V20,
-        "data_path": DEFAULT_PARQUET_PATH,
         "robot_type": robot_type,
"total_episodes": total_episodes, "total_frames": len(dataset), @@ -551,15 +553,13 @@ def convert_dataset( "chunks_size": DEFAULT_CHUNK_SIZE, "fps": metadata_v1["fps"], "splits": {"train": f"0:{total_episodes}"}, - "keys": keys["sequence"], - "video_keys": keys["video"], - "image_keys": keys["image"], - "shapes": {**sequence_shapes, **video_shapes, **image_shapes}, - "names": names, - "videos": videos_info, + "data_path": DEFAULT_PARQUET_PATH, + "video_path": DEFAULT_VIDEO_PATH if video_keys else None, + "features": features, } write_json(metadata_v2_0, v20_dir / INFO_PATH) convert_stats_to_json(v1x_dir, v20_dir) + card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0) with contextlib.suppress(EntryNotFoundError): hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch) @@ -585,28 +585,11 @@ def convert_dataset( revision=branch, ) - card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0) card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch) if not test_branch: create_branch(repo_id=repo_id, branch=V20, repo_type="dataset") - # TODO: - # - [X] Add shapes - # - [X] Add keys - # - [X] Add paths - # - [X] convert stats.json - # - [X] Add task.json - # - [X] Add names - # - [X] Add robot_type - # - [X] Add splits - # - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch - # - [X] Handle multitask datasets - # - [X] Handle hf hub repo limits (add chunks logic) - # - [X] Add test-branch - # - [X] Use jsonlines for episodes - # - [X] Add sanity checks (encoding, shapes) - def main(): parser = argparse.ArgumentParser() diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 48f22435..80cc79cc 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -25,7 +25,6 @@ from typing import Any, ClassVar import pyarrow as pa import torch import torchvision -from datasets import Dataset from datasets.features.features import register_feature from PIL import Image @@ -292,33 +291,6 @@ def get_video_info(video_path: Path | str) -> dict: return video_info -def get_video_shapes(videos_info: dict, video_keys: list) -> dict: - video_shapes = {} - for img_key in video_keys: - channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"]) - video_shapes[img_key] = { - "width": videos_info[img_key]["video.width"], - "height": videos_info[img_key]["video.height"], - "channels": channels, - } - - return video_shapes - - -def get_image_shapes(dataset: Dataset, image_keys: list) -> dict: - image_shapes = {} - for img_key in image_keys: - image = dataset[0][img_key] # Assuming first row - channels = get_image_pixel_channels(image) - image_shapes[img_key] = { - "width": image.width, - "height": image.height, - "channels": channels, - } - - return image_shapes - - def get_video_pixel_channels(pix_fmt: str) -> int: if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt: return 1 diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 3385e7bb..6bdad3e6 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -226,13 +226,42 @@ class ManipulatorRobot: self.is_connected = False self.logs = {} - action_names = [f"{arm}_{motor}" for arm, bus in self.leader_arms.items() for motor in bus.motors] - state_names = [f"{arm}_{motor}" for arm, bus in self.follower_arms.items() for motor 
-        self.names = {
-            "action": action_names,
-            "observation.state": state_names,
+    def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
+        return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
+
+    @property
+    def camera_features(self) -> dict:
+        cam_ft = {}
+        for cam_key, cam in self.cameras.items():
+            key = f"observation.images.{cam_key}"
+            cam_ft[key] = {
+                "shape": (cam.width, cam.height, cam.channels),
+                "names": ["width", "height", "channels"],
+                "info": None,
+            }
+        return cam_ft
+
+    @property
+    def motor_features(self) -> dict:
+        action_names = self.get_motor_names(self.leader_arms)
+        state_names = self.get_motor_names(self.follower_arms)
+        return {
+            "action": {
+                "dtype": "float32",
+                "shape": (len(action_names),),
+                "names": action_names,
+            },
+            "observation.state": {
+                "dtype": "float32",
+                "shape": (len(state_names),),
+                "names": state_names,
+            },
         }
+
+    @property
+    def features(self):
+        return {**self.motor_features, **self.camera_features}
 
     @property
     def has_camera(self):
         return len(self.cameras) > 0
diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py
index 5cd5bd10..a40db131 100644
--- a/lerobot/common/robot_devices/robots/utils.py
+++ b/lerobot/common/robot_devices/robots/utils.py
@@ -11,6 +11,7 @@ def get_arm_id(name, arm_type):
 class Robot(Protocol):
     # TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
     robot_type: str
+    features: dict
 
     def connect(self): ...
     def run_calibration(self): ...
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index a0841d00..e6218787 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -105,7 +105,7 @@ from pathlib import Path
 from typing import List
 
 # from safetensors.torch import load_file, save_file
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.robot_devices.control_utils import (
     control_loop,
     has_method,
@@ -234,15 +234,12 @@ def record(
 
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
-    dataset_metadata = LeRobotDatasetMetadata.create(
+    dataset = LeRobotDataset.create(
        repo_id,
         fps,
         root=root,
         robot=robot,
         use_videos=video,
-    )
-    dataset = LeRobotDataset.create(
-        dataset_metadata,
         image_writer_processes=num_image_writer_processes,
         image_writer_threads=num_image_writer_threads_per_camera,
     )
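
Example usage of the refactored API (a minimal sketch, not exercised by this patch): a dataset can now be created from an explicit `features` dict instead of a `Robot`, with each key declaring its `dtype`, `shape` and per-dimension `names`, matching the `features` entry written to info.json above. The repo id, robot type, fps and shapes below are hypothetical.

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Hypothetical feature schema: motors as float32 vectors, one camera stored as video.
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": [f"motor_{i}" for i in range(6)]},
        "observation.state": {"dtype": "float32", "shape": (6,), "names": [f"motor_{i}" for i in range(6)]},
        "observation.images.cam": {"dtype": "video", "shape": (640, 480, 3), "names": ["width", "height", "channel"]},
    }

    # DEFAULT_FEATURES (timestamp, frame_index, episode_index, index, task_index) are merged in by
    # LeRobotDatasetMetadata.create when features are passed explicitly.
    dataset = LeRobotDataset.create(
        repo_id="user/my_dataset",  # hypothetical repo id
        fps=30,
        robot_type="my_robot",  # hypothetical robot type
        features=features,
        use_videos=True,
    )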