From e4ba084e259cdec3cccae617806a4ebd0c0154d0 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 3 Nov 2024 18:07:37 +0100 Subject: [PATCH] Add LeRobotDatasetMetadata --- benchmarks/video/run_video_benchmark.py | 2 +- examples/1_load_lerobot_dataset.py | 10 +- examples/3_train_policy.py | 2 +- examples/6_add_image_transforms.py | 4 +- .../advanced/2_calculate_validation_loss.py | 1 + lerobot/common/datasets/compute_stats.py | 2 +- lerobot/common/datasets/factory.py | 2 +- lerobot/common/datasets/lerobot_dataset.py | 569 +++++++++--------- lerobot/common/datasets/utils.py | 2 +- lerobot/scripts/control_robot.py | 12 +- lerobot/scripts/eval.py | 2 +- lerobot/scripts/train.py | 2 +- lerobot/scripts/visualize_dataset.py | 2 +- lerobot/scripts/visualize_dataset_html.py | 11 +- lerobot/scripts/visualize_image_transforms.py | 2 +- tests/fixtures/dataset_factories.py | 67 ++- tests/fixtures/files.py | 4 +- tests/fixtures/hub.py | 8 +- .../save_image_transforms_to_safetensors.py | 2 +- tests/scripts/save_policy_to_safetensors.py | 2 +- tests/test_control_robot.py | 4 +- tests/test_datasets.py | 23 +- tests/test_examples.py | 3 +- tests/test_policies.py | 7 +- tests/test_push_dataset_to_hub.py | 1 + 25 files changed, 419 insertions(+), 327 deletions(-) diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index 46806c07..e9066487 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -266,7 +266,7 @@ def benchmark_encoding_decoding( ) ep_num_images = dataset.episode_data_index["to"][0].item() - width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:]) + width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:]) num_pixels = width * height video_size_bytes = video_path.stat().st_size images_size_bytes = get_directory_size(imgs_dir) diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py index 9f291dc5..2647078c 100644 --- a/examples/1_load_lerobot_dataset.py +++ b/examples/1_load_lerobot_dataset.py @@ -13,6 +13,7 @@ Features included in this script: The script ends with examples of how to batch process data using PyTorch's DataLoader. """ +# TODO(aliberts, rcadene): Update this script with the new v2 api from pathlib import Path from pprint import pprint @@ -31,7 +32,7 @@ repo_id = "lerobot/pusht" # You can easily load a dataset from a Hugging Face repository dataset = LeRobotDataset(repo_id) -# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset +# LeRobotDataset actually wraps an underlying Hugging Face dataset # (see https://huggingface.co/docs/datasets/index for more information). print(dataset) print(dataset.hf_dataset) @@ -39,7 +40,7 @@ print(dataset.hf_dataset) # And provides additional utilities for robotics and compatibility with Pytorch print(f"\naverage number of frames per episode: {dataset.num_frames / dataset.num_episodes:.3f}") print(f"frames per second used during data collection: {dataset.fps=}") -print(f"keys to access images from cameras: {dataset.camera_keys=}\n") +print(f"keys to access images from cameras: {dataset.meta.camera_keys=}\n") # Access frame indexes associated to first episode episode_index = 0 @@ -60,14 +61,15 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames] Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True) imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps) + # For many machine learning applications we need to load the history of past observations or trajectories of # future actions. Our datasets can load previous and future frames for each key/modality, using timestamps # differences with the current loaded frame. For instance: delta_timestamps = { # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame "observation.image": [-1, -0.5, -0.20, 0], - # loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame - "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0], + # loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame + "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0], # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future "action": [t / dataset.fps for t in range(64)], } diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index c5ce0d18..935ab2db 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -40,7 +40,7 @@ dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps) # For this example, no arguments need to be passed because the defaults are set up for PushT. # If you're doing something different, you will likely need to change at least some of the defaults. cfg = DiffusionConfig() -policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats) +policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats) policy.train() policy.to(device) diff --git a/examples/6_add_image_transforms.py b/examples/6_add_image_transforms.py index bdcc6d7b..50465287 100644 --- a/examples/6_add_image_transforms.py +++ b/examples/6_add_image_transforms.py @@ -20,7 +20,7 @@ dataset = LeRobotDataset(dataset_repo_id) first_idx = dataset.episode_data_index["from"][0].item() # Get the frame corresponding to the first camera -frame = dataset[first_idx][dataset.camera_keys[0]] +frame = dataset[first_idx][dataset.meta.camera_keys[0]] # Define the transformations @@ -36,7 +36,7 @@ transforms = v2.Compose( transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms) # Get a frame from the transformed dataset -transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]] +transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]] # Create a directory to store output images output_dir = Path("outputs/image_transforms") diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py index 1428014b..b312b7d0 100644 --- a/examples/advanced/2_calculate_validation_loss.py +++ b/examples/advanced/2_calculate_validation_loss.py @@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t on the target environment, whether that be in simulation or the real world. """ +# TODO(aliberts, rcadene): Update this script with the new v2 api import math from pathlib import Path diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index c06c74de..e773bd30 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -42,7 +42,7 @@ def get_stats_einops_patterns(dataset, num_workers=0): assert batch[key].dtype != torch.float64 # if isinstance(feats_type, (VideoFrame, Image)): - if key in dataset.camera_keys: + if key in dataset.meta.camera_keys: # sanity check that images are channel first _, c, h, w = batch[key].shape assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}" diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 04b6e57b..f6164ed1 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -111,6 +111,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData for stats_type, listconfig in stats_dict.items(): # example of stats_type: min, max, mean, std stats = OmegaConf.to_container(listconfig, resolve=True) - dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) return dataset diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index bbeb25d6..f5932b7e 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -45,7 +45,7 @@ from lerobot.common.datasets.utils import ( get_episode_data_index, get_hub_safe_version, hf_transform_to_torch, - load_episode_dicts, + load_episodes, load_info, load_stats, load_tasks, @@ -66,6 +66,237 @@ CODEBASE_VERSION = "v2.0" LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser() +class LeRobotDatasetMetadata: + def __init__( + self, + repo_id: str, + root: Path | None = None, + local_files_only: bool = False, + ): + self.repo_id = repo_id + self.root = root if root is not None else LEROBOT_HOME / repo_id + self.local_files_only = local_files_only + + # Load metadata + (self.root / "meta").mkdir(exist_ok=True, parents=True) + self.pull_from_repo(allow_patterns="meta/") + self.info = load_info(self.root) + self.stats = load_stats(self.root) + self.tasks = load_tasks(self.root) + self.episodes = load_episodes(self.root) + + def pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self._hub_version, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + local_files_only=self.local_files_only, + ) + + @cached_property + def _hub_version(self) -> str | None: + return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION) + + @property + def _version(self) -> str: + """Codebase version used to create this dataset.""" + return self.info["codebase_version"] + + def get_data_file_path(self, ep_index: int) -> Path: + ep_chunk = self.get_episode_chunk(ep_index) + fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index) + return Path(fpath) + + def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + ep_chunk = self.get_episode_chunk(ep_index) + fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) + return Path(fpath) + + def get_episode_chunk(self, ep_index: int) -> int: + return ep_index // self.chunks_size + + @property + def data_path(self) -> str: + """Formattable string for the parquet files.""" + return self.info["data_path"] + + @property + def videos_path(self) -> str | None: + """Formattable string for the video files.""" + return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.info["fps"] + + @property + def keys(self) -> list[str]: + """Keys to access non-image data (state, actions etc.).""" + return self.info["keys"] + + @property + def image_keys(self) -> list[str]: + """Keys to access visual modalities stored as images.""" + return self.info["image_keys"] + + @property + def video_keys(self) -> list[str]: + """Keys to access visual modalities stored as videos.""" + return self.info["video_keys"] + + @property + def camera_keys(self) -> list[str]: + """Keys to access visual modalities (regardless of their storage method).""" + return self.image_keys + self.video_keys + + @property + def names(self) -> dict[list[str]]: + """Names of the various dimensions of vector modalities.""" + return self.info["names"] + + @property + def total_episodes(self) -> int: + """Total number of episodes available.""" + return self.info["total_episodes"] + + @property + def total_frames(self) -> int: + """Total number of frames saved in this dataset.""" + return self.info["total_frames"] + + @property + def total_tasks(self) -> int: + """Total number of different tasks performed in this dataset.""" + return self.info["total_tasks"] + + @property + def total_chunks(self) -> int: + """Total number of chunks (groups of episodes).""" + return self.info["total_chunks"] + + @property + def chunks_size(self) -> int: + """Max number of episodes per chunk.""" + return self.info["chunks_size"] + + @property + def shapes(self) -> dict: + """Shapes for the different features.""" + return self.info["shapes"] + + @property + def task_to_task_index(self) -> dict: + return {task: task_idx for task_idx, task in self.tasks.items()} + + def get_task_index(self, task: str) -> int: + """ + Given a task in natural language, returns its task_index if the task already exists in the dataset, + otherwise creates a new task_index. + """ + task_index = self.task_to_task_index.get(task, None) + return task_index if task_index is not None else self.total_tasks + + def add_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None: + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + + if task_index not in self.tasks: + self.info["total_tasks"] += 1 + self.tasks[task_index] = task + task_dict = { + "task_index": task_index, + "task": task, + } + append_jsonlines(task_dict, self.root / TASKS_PATH) + + chunk = self.get_episode_chunk(episode_index) + if chunk >= self.total_chunks: + self.info["total_chunks"] += 1 + + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + self.info["total_videos"] += len(self.video_keys) + write_json(self.info, self.root / INFO_PATH) + + episode_dict = { + "episode_index": episode_index, + "tasks": [task], + "length": episode_length, + } + self.episodes.append(episode_dict) + append_jsonlines(episode_dict, self.root / EPISODES_PATH) + + def write_video_info(self) -> None: + """ + Warning: this function writes info from first episode videos, implicitly assuming that all videos have + been encoded the same way. Also, this means it assumes the first episode exists. + """ + for key in self.video_keys: + if key not in self.info["videos"]: + video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) + self.info["videos"][key] = get_video_info(video_path) + + write_json(self.info, self.root / INFO_PATH) + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + root: Path | None = None, + robot: Robot | None = None, + robot_type: str | None = None, + keys: list[str] | None = None, + image_keys: list[str] | None = None, + video_keys: list[str] = None, + shapes: dict | None = None, + names: dict | None = None, + use_videos: bool = True, + ) -> "LeRobotDatasetMetadata": + """Creates metadata for a LeRobotDataset.""" + obj = cls.__new__(cls) + obj.repo_id = repo_id + obj.root = root if root is not None else LEROBOT_HOME / repo_id + obj.image_writer = None + + if robot is not None: + robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos) + if not all(cam.fps == fps for cam in robot.cameras.values()): + logging.warning( + f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset." + "In this case, frames from lower fps cameras will be repeated to fill in the blanks" + ) + elif ( + robot_type is None + or keys is None + or image_keys is None + or video_keys is None + or shapes is None + or names is None + ): + raise ValueError( + "Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation." + ) + + if len(video_keys) > 0 and not use_videos: + raise ValueError() + + obj.tasks, obj.stats, obj.episodes = {}, {}, [] + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names + ) + write_json(obj.info, obj.root / INFO_PATH) + obj.local_files_only = True + return obj + + class LeRobotDataset(torch.utils.data.Dataset): def __init__( self, @@ -86,9 +317,9 @@ class LeRobotDataset(torch.utils.data.Dataset): - On your local disk in the 'root' folder. This is typically the case when you recorded your dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class with 'root' will load your dataset directly from disk. This can happen while you're offline (no - internet connection). + internet connection), in that case, use local_files_only=True. - - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and is not on + - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download the dataset from that address and load it, pending your dataset is compliant with codebase_version v2.0. If your dataset has been created before this new format, you will be @@ -96,9 +327,9 @@ class LeRobotDataset(torch.utils.data.Dataset): lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py. - 2. Your dataset doesn't already exists (either on local disk or on the Hub): - You can create an empty LeRobotDataset with the 'create' classmethod. This can be used for - recording a dataset or port an existing dataset to the LeRobotDataset format. + 2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty + LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or port an + existing dataset to the LeRobotDataset format. In terms of files, LeRobotDataset encapsulates 3 main things: @@ -192,21 +423,18 @@ class LeRobotDataset(torch.utils.data.Dataset): self.image_writer = None self.episode_buffer = {} - # Load metadata self.root.mkdir(exist_ok=True, parents=True) - self.pull_from_repo(allow_patterns="meta/") - self.info = load_info(self.root) - self.stats = load_stats(self.root) - self.tasks = load_tasks(self.root) - self.episode_dicts = load_episode_dicts(self.root) + + # Load metadata + self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only) # Check version - check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) + check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) # Load actual data self.download_episodes(download_videos) self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes) + self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) # Check timestamps check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) @@ -216,26 +444,6 @@ class LeRobotDataset(torch.utils.data.Dataset): check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) - # TODO(aliberts): - # - [X] Move delta_timestamp logic outside __get_item__ - # - [X] Update __get_item__ - # - [/] Add doc - # - [ ] Add self.add_frame() - # - [ ] Add self.consolidate() for: - # - [X] Check timestamps sync - # - [ ] Sanity checks (episodes num, shapes, files, etc.) - # - [ ] Update episode_index (arg update=True) - # - [ ] Update info.json (arg update=True) - - @cached_property - def _hub_version(self) -> str | None: - return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION) - - @property - def _version(self) -> str: - """Codebase version used to create this dataset.""" - return self.info["codebase_version"] - def push_to_hub(self, push_videos: bool = True) -> None: if not self.consolidated: raise RuntimeError( @@ -262,7 +470,7 @@ class LeRobotDataset(torch.utils.data.Dataset): snapshot_download( self.repo_id, repo_type="dataset", - revision=self._hub_version, + revision=self.meta._hub_version, local_dir=self.root, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, @@ -280,11 +488,11 @@ class LeRobotDataset(torch.utils.data.Dataset): files = None ignore_patterns = None if download_videos else "videos/" if self.episodes is not None: - files = [str(self.get_data_file_path(ep_idx)) for ep_idx in self.episodes] - if len(self.video_keys) > 0 and download_videos: + files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] + if len(self.meta.video_keys) > 0 and download_videos: video_files = [ - str(self.get_video_file_path(ep_idx, vid_key)) - for vid_key in self.video_keys + str(self.meta.get_video_file_path(ep_idx, vid_key)) + for vid_key in self.meta.video_keys for ep_idx in self.episodes ] files += video_files @@ -297,108 +505,30 @@ class LeRobotDataset(torch.utils.data.Dataset): path = str(self.root / "data") hf_dataset = load_dataset("parquet", data_dir=path, split="train") else: - files = [str(self.root / self.get_data_file_path(ep_idx)) for ep_idx in self.episodes] + files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] hf_dataset = load_dataset("parquet", data_files=files, split="train") hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset - def get_data_file_path(self, ep_index: int) -> Path: - ep_chunk = self.get_episode_chunk(ep_index) - fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index) - return Path(fpath) - - def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: - ep_chunk = self.get_episode_chunk(ep_index) - fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) - return Path(fpath) - - def get_episode_chunk(self, ep_index: int) -> int: - return ep_index // self.chunks_size - - @property - def data_path(self) -> str: - """Formattable string for the parquet files.""" - return self.info["data_path"] - - @property - def videos_path(self) -> str | None: - """Formattable string for the video files.""" - return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None - @property def fps(self) -> int: """Frames per second used during data collection.""" - return self.info["fps"] - - @property - def keys(self) -> list[str]: - """Keys to access non-image data (state, actions etc.).""" - return self.info["keys"] - - @property - def image_keys(self) -> list[str]: - """Keys to access visual modalities stored as images.""" - return self.info["image_keys"] - - @property - def video_keys(self) -> list[str]: - """Keys to access visual modalities stored as videos.""" - return self.info["video_keys"] - - @property - def camera_keys(self) -> list[str]: - """Keys to access visual modalities (regardless of their storage method).""" - return self.image_keys + self.video_keys - - @property - def names(self) -> dict[list[str]]: - """Names of the various dimensions of vector modalities.""" - return self.info["names"] + return self.meta.fps @property def num_frames(self) -> int: """Number of frames in selected episodes.""" - return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames + return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames @property def num_episodes(self) -> int: """Number of episodes selected.""" - return len(self.episodes) if self.episodes is not None else self.total_episodes - - @property - def total_episodes(self) -> int: - """Total number of episodes available.""" - return self.info["total_episodes"] - - @property - def total_frames(self) -> int: - """Total number of frames saved in this dataset.""" - return self.info["total_frames"] - - @property - def total_tasks(self) -> int: - """Total number of different tasks performed in this dataset.""" - return self.info["total_tasks"] - - @property - def total_chunks(self) -> int: - """Total number of chunks (groups of episodes).""" - return self.info["total_chunks"] - - @property - def chunks_size(self) -> int: - """Max number of episodes per chunk.""" - return self.info["chunks_size"] - - @property - def shapes(self) -> dict: - """Shapes for the different features.""" - return self.info["shapes"] + return len(self.episodes) if self.episodes is not None else self.meta.total_episodes @property def features(self) -> list[str]: - return list(self._features) + self.video_keys + return list(self._features) + self.meta.video_keys @property def _features(self) -> datasets.Features: @@ -418,39 +548,15 @@ class LeRobotDataset(torch.utils.data.Dataset): features[key] = datasets.Value(dtype="bool") elif key in ["timestamp", "next.reward"]: features[key] = datasets.Value(dtype="float32") - elif key in self.image_keys: + elif key in self.meta.image_keys: features[key] = datasets.Image() - elif key in self.keys: + elif key in self.meta.keys: features[key] = datasets.Sequence( - length=self.shapes[key], feature=datasets.Value(dtype="float32") + length=self.meta.shapes[key], feature=datasets.Value(dtype="float32") ) return datasets.Features(features) - @property - def task_to_task_index(self) -> dict: - return {task: task_idx for task_idx, task in self.tasks.items()} - - def get_task_index(self, task: str) -> int: - """ - Given a task in natural language, returns its task_index if the task already exists in the dataset, - otherwise creates a new task_index. - """ - task_index = self.task_to_task_index.get(task, None) - return task_index if task_index is not None else self.total_tasks - - def current_episode_index(self, idx: int) -> int: - episode_index = self.hf_dataset["episode_index"][idx] - if self.episodes is not None: - # get episode_index from selected episodes - episode_index = self.episodes.index(episode_index) - - return episode_index - - def episode_length(self, episode_index) -> int: - """Number of samples/frames for given episode.""" - return self.info["episodes"][episode_index]["length"] - def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: ep_start = self.episode_data_index["from"][ep_idx] ep_end = self.episode_data_index["to"][ep_idx] @@ -472,7 +578,7 @@ class LeRobotDataset(torch.utils.data.Dataset): query_indices: dict[str, list[int]] | None = None, ) -> dict[str, list[float]]: query_timestamps = {} - for key in self.video_keys: + for key in self.meta.video_keys: if query_indices is not None and key in query_indices: timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] query_timestamps[key] = torch.stack(timestamps).tolist() @@ -485,7 +591,7 @@ class LeRobotDataset(torch.utils.data.Dataset): return { key: torch.stack(self.hf_dataset.select(q_idx)[key]) for key, q_idx in query_indices.items() - if key not in self.video_keys + if key not in self.meta.video_keys } def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict: @@ -496,7 +602,7 @@ class LeRobotDataset(torch.utils.data.Dataset): """ item = {} for vid_key, query_ts in query_timestamps.items(): - video_path = self.root / self.get_video_file_path(ep_idx, vid_key) + video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) frames = decode_video_frames_torchvision( video_path, query_ts, self.tolerance_s, self.video_backend ) @@ -525,14 +631,14 @@ class LeRobotDataset(torch.utils.data.Dataset): for key, val in query_result.items(): item[key] = val - if len(self.video_keys) > 0: + if len(self.meta.video_keys) > 0: current_ts = item["timestamp"].item() query_timestamps = self._get_query_timestamps(current_ts, query_indices) video_frames = self._query_videos(query_timestamps, ep_idx) item = {**video_frames, **item} if self.image_transforms is not None: - image_keys = self.camera_keys + image_keys = self.meta.camera_keys for cam in image_keys: item[cam] = self.image_transforms(item[cam]) @@ -545,20 +651,20 @@ class LeRobotDataset(torch.utils.data.Dataset): f" Selected episodes: {self.episodes},\n" f" Number of selected episodes: {self.num_episodes},\n" f" Number of selected samples: {self.num_frames},\n" - f"\n{json.dumps(self.info, indent=4)}\n" + f"\n{json.dumps(self.meta.info, indent=4)}\n" ) def _create_episode_buffer(self, episode_index: int | None = None) -> dict: # TODO(aliberts): Handle resume return { "size": 0, - "episode_index": self.total_episodes if episode_index is None else episode_index, + "episode_index": self.meta.total_episodes if episode_index is None else episode_index, "task_index": None, "frame_index": [], "timestamp": [], "next.done": [], - **{key: [] for key in self.keys}, - **{key: [] for key in self.image_keys}, + **{key: [] for key in self.meta.keys}, + **{key: [] for key in self.meta.image_keys}, } def add_frame(self, frame: dict) -> None: @@ -573,7 +679,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episode_buffer["next.done"].append(False) # Save all observed modalities except images - for key in self.keys: + for key in self.meta.keys: self.episode_buffer[key].append(frame[key]) self.episode_buffer["size"] += 1 @@ -582,7 +688,7 @@ class LeRobotDataset(torch.utils.data.Dataset): return # Save images - for cam_key in self.camera_keys: + for cam_key in self.meta.camera_keys: img_path = self.image_writer.get_image_file_path( episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index ) @@ -594,7 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset): fpath=img_path, ) - if cam_key in self.image_keys: + if cam_key in self.meta.image_keys: self.episode_buffer[cam_key].append(str(img_path)) def add_episode(self, task: str, encode_videos: bool = False) -> None: @@ -609,17 +715,17 @@ class LeRobotDataset(torch.utils.data.Dataset): """ episode_length = self.episode_buffer.pop("size") episode_index = self.episode_buffer["episode_index"] - if episode_index != self.total_episodes: + if episode_index != self.meta.total_episodes: # TODO(aliberts): Add option to use existing episode_index raise NotImplementedError() - task_index = self.get_task_index(task) + task_index = self.meta.get_task_index(task) self.episode_buffer["next.done"][-1] = True for key in self.episode_buffer: - if key in self.image_keys: + if key in self.meta.image_keys: continue - elif key in self.keys: + elif key in self.meta.keys: self.episode_buffer[key] = torch.stack(self.episode_buffer[key]) elif key == "episode_index": self.episode_buffer[key] = torch.full((episode_length,), episode_index) @@ -628,13 +734,15 @@ class LeRobotDataset(torch.utils.data.Dataset): else: self.episode_buffer[key] = torch.tensor(self.episode_buffer[key]) - self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length) - self._save_episode_to_metadata(episode_index, episode_length, task, task_index) + self.episode_buffer["index"] = torch.arange( + self.meta.total_frames, self.meta.total_frames + episode_length + ) + self.meta.add_episode(episode_index, episode_length, task, task_index) self._wait_image_writer() self._save_episode_table(episode_index) - if encode_videos and len(self.video_keys) > 0: + if encode_videos and len(self.meta.video_keys) > 0: self.encode_videos() # Reset the buffer @@ -643,45 +751,14 @@ class LeRobotDataset(torch.utils.data.Dataset): def _save_episode_table(self, episode_index: int) -> None: ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train") - ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index) + ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index) ep_data_path.parent.mkdir(parents=True, exist_ok=True) write_parquet(ep_dataset, ep_data_path) - def _save_episode_to_metadata( - self, episode_index: int, episode_length: int, task: str, task_index: int - ) -> None: - self.info["total_episodes"] += 1 - self.info["total_frames"] += episode_length - - if task_index not in self.tasks: - self.info["total_tasks"] += 1 - self.tasks[task_index] = task - task_dict = { - "task_index": task_index, - "task": task, - } - append_jsonlines(task_dict, self.root / TASKS_PATH) - - chunk = self.get_episode_chunk(episode_index) - if chunk >= self.total_chunks: - self.info["total_chunks"] += 1 - - self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} - self.info["total_videos"] += len(self.video_keys) - write_json(self.info, self.root / INFO_PATH) - - episode_dict = { - "episode_index": episode_index, - "tasks": [task], - "length": episode_length, - } - self.episode_dicts.append(episode_dict) - append_jsonlines(episode_dict, self.root / EPISODES_PATH) - def clear_episode_buffer(self) -> None: episode_index = self.episode_buffer["episode_index"] if self.image_writer is not None: - for cam_key in self.camera_keys: + for cam_key in self.meta.camera_keys: img_dir = self.image_writer.get_episode_dir(episode_index, cam_key) if img_dir.is_dir(): shutil.rmtree(img_dir) @@ -717,12 +794,12 @@ class LeRobotDataset(torch.utils.data.Dataset): def encode_videos(self) -> None: # Use ffmpeg to convert frames stored as png into mp4 videos - for episode_index in range(self.total_episodes): - for key in self.video_keys: + for episode_index in range(self.meta.total_episodes): + for key in self.meta.video_keys: # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need # to call self.image_writer here tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key) - video_path = self.root / self.get_video_file_path(episode_index, key) + video_path = self.root / self.meta.get_video_file_path(episode_index, key) if video_path.is_file(): # Skip if video is already encoded. Could be the case when resuming data recording. continue @@ -730,40 +807,28 @@ class LeRobotDataset(torch.utils.data.Dataset): # since video encoding with ffmpeg is already using multithreading. encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True) - def _write_video_info(self) -> None: - """ - Warning: this function writes info from first episode videos, implicitly assuming that all videos have - been encoded the same way. Also, this means it assumes the first episode exists. - """ - for key in self.video_keys: - if key not in self.info["videos"]: - video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) - self.info["videos"][key] = get_video_info(video_path) - - write_json(self.info, self.root / INFO_PATH) - def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None: self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes) + self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) - if len(self.video_keys) > 0: + if len(self.meta.video_keys) > 0: self.encode_videos() - self._write_video_info() + self.meta.write_video_info() if not keep_image_files and self.image_writer is not None: shutil.rmtree(self.image_writer.write_dir) video_files = list(self.root.rglob("*.mp4")) - assert len(video_files) == self.num_episodes * len(self.video_keys) + assert len(video_files) == self.num_episodes * len(self.meta.video_keys) parquet_files = list(self.root.rglob("*.parquet")) assert len(parquet_files) == self.num_episodes if run_compute_stats: self.stop_image_writer() - self.stats = compute_stats(self) - write_stats(self.stats, self.root / STATS_PATH) + self.meta.stats = compute_stats(self) + write_stats(self.meta.stats, self.root / STATS_PATH) self.consolidated = True else: logging.warning( @@ -780,60 +845,23 @@ class LeRobotDataset(torch.utils.data.Dataset): @classmethod def create( cls, - repo_id: str, - fps: int, - root: Path | None = None, - robot: Robot | None = None, - robot_type: str | None = None, - keys: list[str] | None = None, - image_keys: list[str] | None = None, - video_keys: list[str] = None, - shapes: dict | None = None, - names: dict | None = None, + metadata: LeRobotDatasetMetadata, tolerance_s: float = 1e-4, image_writer_processes: int = 0, - image_writer_threads_per_camera: int = 0, - use_videos: bool = True, + image_writer_threads: int = 0, video_backend: str | None = None, ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" obj = cls.__new__(cls) - obj.repo_id = repo_id - obj.root = root if root is not None else LEROBOT_HOME / repo_id + obj.meta = metadata + obj.repo_id = obj.meta.repo_id + obj.root = obj.meta.root + obj.local_files_only = obj.meta.local_files_only obj.tolerance_s = tolerance_s obj.image_writer = None - if robot is not None: - robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos) - if not all(cam.fps == fps for cam in robot.cameras.values()): - logging.warning( - f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset." - "In this case, frames from lower fps cameras will be repeated to fill in the blanks" - ) - if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera): - obj.start_image_writer( - image_writer_processes, image_writer_threads_per_camera * robot.num_cameras - ) - elif ( - robot_type is None - or keys is None - or image_keys is None - or video_keys is None - or shapes is None - or names is None - ): - raise ValueError( - "Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation." - ) - - if len(video_keys) > 0 and not use_videos: - raise ValueError() - - obj.tasks, obj.stats, obj.episode_dicts = {}, {}, [] - obj.info = create_empty_dataset_info( - CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names - ) - write_json(obj.info, obj.root / INFO_PATH) + if image_writer_processes or image_writer_threads: + obj.start_image_writer(image_writer_processes, image_writer_threads) # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer obj.episode_buffer = obj._create_episode_buffer() @@ -849,7 +877,6 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.image_transforms = None obj.delta_timestamps = None obj.delta_indices = None - obj.local_files_only = True obj.episode_data_index = None obj.video_backend = video_backend if video_backend is not None else "pyav" return obj @@ -889,7 +916,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): # Check that some properties are consistent across datasets. Note: We may relax some of these # consistency requirements in future iterations of this class. for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True): - if dataset.info != self._datasets[0].info: + if dataset.meta.info != self._datasets[0].meta.info: raise ValueError( f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is " "not yet supported." @@ -938,7 +965,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. """ - return self._datasets[0].info["fps"] + return self._datasets[0].meta.info["fps"] @property def video(self) -> bool: @@ -948,7 +975,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. """ - return self._datasets[0].info.get("video", False) + return self._datasets[0].meta.info.get("video", False) @property def features(self) -> datasets.Features: diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 0e60af3f..5ade25ae 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -139,7 +139,7 @@ def load_tasks(local_dir: Path) -> dict: return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} -def load_episode_dicts(local_dir: Path) -> dict: +def load_episodes(local_dir: Path) -> dict: return load_jsonlines(local_dir / EPISODES_PATH) diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index f23fee38..a0841d00 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -105,7 +105,7 @@ from pathlib import Path from typing import List # from safetensors.torch import load_file, save_file -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.robot_devices.control_utils import ( control_loop, has_method, @@ -234,15 +234,18 @@ def record( # Create empty dataset or load existing saved episodes sanity_check_dataset_name(repo_id, policy) - dataset = LeRobotDataset.create( + dataset_metadata = LeRobotDatasetMetadata.create( repo_id, fps, root=root, robot=robot, - image_writer_processes=num_image_writer_processes, - image_writer_threads_per_camera=num_image_writer_threads_per_camera, use_videos=video, ) + dataset = LeRobotDataset.create( + dataset_metadata, + image_writer_processes=num_image_writer_processes, + image_writer_threads=num_image_writer_threads_per_camera, + ) if not robot.is_connected: robot.connect() @@ -315,7 +318,6 @@ def record( dataset.consolidate(run_compute_stats) - # lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds) if push_to_hub: dataset.push_to_hub() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 0aec8472..040f92d9 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -484,7 +484,7 @@ def main( policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path)) else: # Note: We need the dataset stats to pass to the policy's normalization modules. - policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) + policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats) assert isinstance(policy, nn.Module) policy.eval() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 8ff3b389..9a0b7e4c 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -328,7 +328,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("make_policy") policy = make_policy( hydra_cfg=cfg, - dataset_stats=offline_dataset.stats if not cfg.resume else None, + dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, ) assert isinstance(policy, nn.Module) diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 6cff5752..d7720c10 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -153,7 +153,7 @@ def visualize_dataset( rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) # display each camera image - for key in dataset.camera_keys: + for key in dataset.meta.camera_keys: # TODO(rcadene): add `.compress()`? is it lossless? rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 10a85bda..b396a369 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -97,8 +97,8 @@ def run_server( "num_episodes": dataset.num_episodes, "fps": dataset.fps, } - video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys] - tasks = dataset.episode_dicts[episode_id]["tasks"] + video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys] + tasks = dataset.meta.episodes[episode_id]["tasks"] videos_info = [ {"url": url_for("static", filename=video_path), "filename": video_path.name} for video_path in video_paths @@ -170,7 +170,8 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str] # get first frame of episode (hack to get video_path of the episode) first_frame_idx = dataset.episode_data_index["from"][ep_index].item() return [ - dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys + dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] + for key in dataset.meta.video_keys ] @@ -202,8 +203,8 @@ def visualize_dataset_html( dataset = LeRobotDataset(repo_id, root=root) - if len(dataset.image_keys) > 0: - raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.") + if len(dataset.meta.image_keys) > 0: + raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.") if output_dir is None: output_dir = f"outputs/visualize_dataset_html/{repo_id}" diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py index e7cd3582..f9fb5c08 100644 --- a/lerobot/scripts/visualize_image_transforms.py +++ b/lerobot/scripts/visualize_image_transforms.py @@ -157,7 +157,7 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5): output_dir.mkdir(parents=True, exist_ok=True) # Get 1st frame from 1st camera of 1st episode - original_frame = dataset[0][dataset.camera_keys[0]] + original_frame = dataset[0][dataset.meta.camera_keys[0]] to_pil(original_frame).save(output_dir / "original_frame.png", quality=100) print("\nOriginal frame saved to:") print(f" {output_dir / 'original_frame.png'}.") diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index b489792a..bbd485b7 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -8,7 +8,7 @@ import PIL.Image import pytest import torch -from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_PARQUET_PATH, @@ -33,8 +33,8 @@ def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | No return shapes -def get_task_index(tasks_dicts: dict, task: str) -> int: - tasks = {d["task_index"]: d["task"] for d in tasks_dicts} +def get_task_index(task_dicts: dict, task: str) -> int: + tasks = {d["task_index"]: d["task"] for d in task_dicts} task_to_task_index = {task: task_idx for task_idx, task in tasks.items()} return task_to_task_index[task] @@ -313,6 +313,47 @@ def hf_dataset_factory(img_array_factory, episodes, tasks): return _create_hf_dataset +@pytest.fixture(scope="session") +def lerobot_dataset_metadata_factory( + info, + stats, + tasks, + episodes, + mock_snapshot_download_factory, +): + def _create_lerobot_dataset_metadata( + root: Path, + repo_id: str = DUMMY_REPO_ID, + info_dict: dict = info, + stats_dict: dict = stats, + task_dicts: list[dict] = tasks, + episode_dicts: list[dict] = episodes, + **kwargs, + ) -> LeRobotDatasetMetadata: + mock_snapshot_download = mock_snapshot_download_factory( + info_dict=info_dict, + stats_dict=stats_dict, + task_dicts=task_dicts, + episode_dicts=episode_dicts, + ) + with ( + patch( + "lerobot.common.datasets.lerobot_dataset.get_hub_safe_version" + ) as mock_get_hub_safe_version_patch, + patch( + "lerobot.common.datasets.lerobot_dataset.snapshot_download" + ) as mock_snapshot_download_patch, + ): + mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version + mock_snapshot_download_patch.side_effect = mock_snapshot_download + + return LeRobotDatasetMetadata( + repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False) + ) + + return _create_lerobot_dataset_metadata + + @pytest.fixture(scope="session") def lerobot_dataset_factory( info, @@ -321,6 +362,7 @@ def lerobot_dataset_factory( episodes, hf_dataset, mock_snapshot_download_factory, + lerobot_dataset_metadata_factory, ): def _create_lerobot_dataset( root: Path, @@ -335,19 +377,26 @@ def lerobot_dataset_factory( mock_snapshot_download = mock_snapshot_download_factory( info_dict=info_dict, stats_dict=stats_dict, - tasks_dicts=task_dicts, - episodes_dicts=episode_dicts, + task_dicts=task_dicts, + episode_dicts=episode_dicts, hf_ds=hf_ds, ) + mock_metadata = lerobot_dataset_metadata_factory( + root=root, + repo_id=repo_id, + info_dict=info_dict, + stats_dict=stats_dict, + task_dicts=task_dicts, + episode_dicts=episode_dicts, + **kwargs, + ) with ( - patch( - "lerobot.common.datasets.lerobot_dataset.get_hub_safe_version" - ) as mock_get_hub_safe_version_patch, + patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch, patch( "lerobot.common.datasets.lerobot_dataset.snapshot_download" ) as mock_snapshot_download_patch, ): - mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version + mock_metadata_patch.return_value = mock_metadata mock_snapshot_download_patch.side_effect = mock_snapshot_download return LeRobotDataset(repo_id=repo_id, root=root, **kwargs) diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 714824f9..a9ee2c35 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -36,11 +36,11 @@ def stats_path(stats): @pytest.fixture(scope="session") def tasks_path(tasks): - def _create_tasks_jsonl_file(dir: Path, tasks_dicts: list = tasks) -> Path: + def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path: fpath = dir / TASKS_PATH fpath.parent.mkdir(parents=True, exist_ok=True) with jsonlines.open(fpath, "w") as writer: - writer.write_all(tasks_dicts) + writer.write_all(task_dicts) return fpath return _create_tasks_jsonl_file diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index 3422936c..8dd9e966 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -26,7 +26,7 @@ def mock_snapshot_download_factory( """ def _mock_snapshot_download_func( - info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset + info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset ): def _extract_episode_index_from_path(fpath: str) -> int: path = Path(fpath) @@ -53,7 +53,7 @@ def mock_snapshot_download_factory( all_files.extend(meta_files) data_files = [] - for episode_dict in episodes_dicts: + for episode_dict in episode_dicts: ep_idx = episode_dict["episode_index"] ep_chunk = ep_idx // info_dict["chunks_size"] data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) @@ -75,9 +75,9 @@ def mock_snapshot_download_factory( elif rel_path == STATS_PATH: _ = stats_path(local_dir, stats_dict) elif rel_path == TASKS_PATH: - _ = tasks_path(local_dir, tasks_dicts) + _ = tasks_path(local_dir, task_dicts) elif rel_path == EPISODES_PATH: - _ = episode_path(local_dir, episodes_dicts) + _ = episode_path(local_dir, episode_dicts) else: pass return str(local_dir) diff --git a/tests/scripts/save_image_transforms_to_safetensors.py b/tests/scripts/save_image_transforms_to_safetensors.py index 9d024a01..1fa194e5 100644 --- a/tests/scripts/save_image_transforms_to_safetensors.py +++ b/tests/scripts/save_image_transforms_to_safetensors.py @@ -76,7 +76,7 @@ def main(): dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None) output_dir = Path(ARTIFACT_DIR) output_dir.mkdir(parents=True, exist_ok=True) - original_frame = dataset[0][dataset.camera_keys[0]] + original_frame = dataset[0][dataset.meta.camera_keys[0]] save_single_transforms(original_frame, output_dir) save_default_config_transform(original_frame, output_dir) diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index 5236b7ae..29d0ae19 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -38,7 +38,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides): ) set_global_seed(1337) dataset = make_dataset(cfg) - policy = make_policy(cfg, dataset_stats=dataset.stats) + policy = make_policy(cfg, dataset_stats=dataset.meta.stats) policy.train() optimizer, _ = make_optimizer_and_scheduler(cfg, policy) diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 6734af2b..88a4d1cd 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -155,7 +155,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): display_cameras=False, play_sounds=False, ) - assert dataset.total_episodes == 2 + assert dataset.meta.total_episodes == 2 assert len(dataset) == 2 replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False) @@ -193,7 +193,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): overrides=overrides, ) - policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) + policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats) optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) out_dir = tmpdir / "logger" logger = Logger(cfg, out_dir, wandb_job_name="debug") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index c46bb51a..d1d49b31 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -33,7 +33,11 @@ from lerobot.common.datasets.compute_stats import ( get_stats_einops_patterns, ) from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset +from lerobot.common.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, + MultiLeRobotDataset, +) from lerobot.common.datasets.utils import ( create_branch, flatten_dict, @@ -53,14 +57,17 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path): # Instantiate both ways robot = make_robot("koch", mock=True) root_create = tmp_path / "create" - dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) + metadata_create = LeRobotDatasetMetadata.create( + repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create + ) + dataset_create = LeRobotDataset.create(metadata_create) root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init) # Access the '_hub_version' cached_property in both instances to force its creation - _ = dataset_init._hub_version - _ = dataset_create._hub_version + _ = dataset_init.meta._hub_version + _ = dataset_create.meta._hub_version init_attr = set(vars(dataset_init).keys()) create_attr = set(vars(dataset_create).keys()) @@ -78,8 +85,8 @@ def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path) dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs) assert dataset.repo_id == kwargs["repo_id"] - assert dataset.total_episodes == kwargs["total_episodes"] - assert dataset.total_frames == kwargs["total_frames"] + assert dataset.meta.total_episodes == kwargs["total_episodes"] + assert dataset.meta.total_frames == kwargs["total_frames"] assert dataset.episodes == kwargs["episodes"] assert dataset.num_episodes == len(kwargs["episodes"]) assert dataset.num_frames == len(dataset) @@ -118,7 +125,7 @@ def test_factory(env_name, repo_id, policy_name): ) dataset = make_dataset(cfg) delta_timestamps = dataset.delta_timestamps - camera_keys = dataset.camera_keys + camera_keys = dataset.meta.camera_keys item = dataset[0] @@ -251,7 +258,7 @@ def test_compute_stats_on_xarm(): assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"]) # load stats used during training which are expected to match the ones returned by computed_stats - loaded_stats = dataset.stats # noqa: F841 + loaded_stats = dataset.meta.stats # noqa: F841 # TODO(rcadene): we can't test this because expected_stats is computed on a subset # # test loaded stats match expected stats diff --git a/tests/test_examples.py b/tests/test_examples.py index 0a6ce422..6b304863 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# TODO(aliberts): Mute logging for these tests + import io import subprocess import sys @@ -29,6 +29,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s return text +# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures def _run_script(path): subprocess.run([sys.executable, path], check=True) diff --git a/tests/test_policies.py b/tests/test_policies.py index f358170d..573a3486 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str): assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation) -# TODO(aliberts): refactor using lerobot/__init__.py variables +@pytest.mark.skip("TODO after v2 migration / removing hydra") @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", [ @@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides): # Check that we can make the policy object. dataset = make_dataset(cfg) - policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) + policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats) # Check that the policy follows the required protocol. assert isinstance( policy, Policy @@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides): env.step(action) +@pytest.mark.skip("TODO after v2 migration / removing hydra") def test_act_backbone_lr(): """ Test that the ACT policy can be instantiated with a different learning rate for the backbone. @@ -213,7 +214,7 @@ def test_act_backbone_lr(): assert cfg.training.lr_backbone == 0.001 dataset = make_dataset(cfg) - policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) + policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats) optimizer, _ = make_optimizer_and_scheduler(cfg, policy) assert len(optimizer.param_groups) == 2 assert optimizer.param_groups[0]["lr"] == cfg.training.lr diff --git a/tests/test_push_dataset_to_hub.py b/tests/test_push_dataset_to_hub.py index f6725f87..bcba38f0 100644 --- a/tests/test_push_dataset_to_hub.py +++ b/tests/test_push_dataset_to_hub.py @@ -250,6 +250,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir): ) +@pytest.mark.skip("TODO after v2 migration / removing hydra") @pytest.mark.parametrize( "required_packages, raw_format, repo_id, make_test_data", [