From 8bd406e6070e200b36b6a9a864011bb4063fcedc Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 11 Oct 2024 18:52:11 +0200 Subject: [PATCH] Add suggestions from code review --- lerobot/common/datasets/lerobot_dataset.py | 26 ++++++++-------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 61d27287..6b149554 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -48,11 +48,10 @@ class LeRobotDataset(torch.utils.data.Dataset): repo_id: str, root: Path | None = None, episodes: list[int] | None = None, - split: str = "train", image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, tolerance_s: float = 1e-4, - download_data: bool = True, + download_videos: bool = True, video_backend: str | None = None, ): """LeRobotDataset encapsulates 3 main things: @@ -64,7 +63,7 @@ class LeRobotDataset(torch.utils.data.Dataset): - hf_dataset (from datasets.Dataset), which will read any values from parquet files. - (optional) videos from which frames are loaded to be synchronous with data from parquet files. - 3 use modes are available for this class, depending on 3 different use cases: + 3 modes are available for this class, depending on 3 different use cases: 1. Your dataset already exists on the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder: @@ -119,7 +118,6 @@ class LeRobotDataset(torch.utils.data.Dataset): '~/.cache/huggingface/lerobot'. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. - split (str, optional): _description_. Defaults to "train". image_transforms (Callable | None, optional): You can pass standard v2 image transforms from torchvision.transforms.v2 here which will be applied to visual modalities (whether they come from videos or images). Defaults to None. @@ -129,19 +127,18 @@ class LeRobotDataset(torch.utils.data.Dataset): timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames decoded from video files. It is also used to check that `delta_timestamps` (when provided) are multiples of 1/fps. Defaults to 1e-4. - download_data (bool, optional): Flag to download actual data. Defaults to True. + download_videos (bool, optional): Flag to download the videos. Defaults to True. video_backend (str | None, optional): Video backend to use for decoding videos. There is currently a single option which is the pyav decoder used by Torchvision. Defaults to pyav. """ super().__init__() self.repo_id = repo_id self.root = root if root is not None else LEROBOT_HOME / repo_id - self.split = split self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.episodes = episodes self.tolerance_s = tolerance_s - self.download_data = download_data + self.download_videos = download_videos self.video_backend = video_backend if video_backend is not None else "pyav" self.delta_indices = None @@ -152,13 +149,6 @@ class LeRobotDataset(torch.utils.data.Dataset): self.stats = load_stats(repo_id, self._version, self.root) self.tasks = load_tasks(repo_id, self._version, self.root) - if not self.download_data: - # TODO(aliberts): Add actual support for this - # maybe use local_files_only=True or HF_HUB_OFFLINE=True - # see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019 - self.hf_dataset, self.episode_data_index = None, None - return - # Load actual data self.download_episodes() self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes) @@ -192,12 +182,13 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads files = None + ignore_patterns = None if self.download_videos else "videos/" if self.episodes is not None: files = [ self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes) for ep_idx in self.episodes ] - if len(self.video_keys) > 0: + if len(self.video_keys) > 0 and self.download_videos: video_files = [ self.videos_path.format(video_key=vid_key, episode_index=ep_idx) for vid_key in self.video_keys @@ -211,6 +202,7 @@ class LeRobotDataset(torch.utils.data.Dataset): revision=self._version, local_dir=self.root, allow_patterns=files, + ignore_patterns=ignore_patterns, ) @property @@ -371,7 +363,8 @@ class LeRobotDataset(torch.utils.data.Dataset): item = {**video_frames, **item} if self.image_transforms is not None: - for cam in self.camera_keys: + image_keys = self.camera_keys if self.download_videos else self.image_keys + for cam in image_keys: item[cam] = self.image_transforms(item[cam]) return item @@ -380,7 +373,6 @@ class LeRobotDataset(torch.utils.data.Dataset): return ( f"{self.__class__.__name__}(\n" f" Repository ID: '{self.repo_id}',\n" - f" Split: '{self.split}',\n" f" Number of Samples: {self.num_samples},\n" f" Number of Episodes: {self.num_episodes},\n" f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"