Add suggestions from code review

This commit is contained in:
Simon Alibert 2024-10-11 18:52:11 +02:00
parent 3ea53124e0
commit 8bd406e607
1 changed file with 9 additions and 17 deletions

View File

@ -48,11 +48,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
repo_id: str, repo_id: str,
root: Path | None = None, root: Path | None = None,
episodes: list[int] | None = None, episodes: list[int] | None = None,
split: str = "train",
image_transforms: Callable | None = None, image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4, tolerance_s: float = 1e-4,
download_data: bool = True, download_videos: bool = True,
video_backend: str | None = None, video_backend: str | None = None,
): ):
"""LeRobotDataset encapsulates 3 main things: """LeRobotDataset encapsulates 3 main things:
@ -64,7 +63,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
- hf_dataset (from datasets.Dataset), which will read any values from parquet files. - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
- (optional) videos from which frames are loaded to be synchronous with data from parquet files. - (optional) videos from which frames are loaded to be synchronous with data from parquet files.
3 use modes are available for this class, depending on 3 different use cases: 3 modes are available for this class, depending on 3 different use cases:
1. Your dataset already exists on the Hugging Face Hub at the address 1. Your dataset already exists on the Hugging Face Hub at the address
https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder: https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder:
@ -119,7 +118,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
'~/.cache/huggingface/lerobot'. '~/.cache/huggingface/lerobot'.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None. their episode_index in this list. Defaults to None.
split (str, optional): _description_. Defaults to "train".
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
from videos or images). Defaults to None. from videos or images). Defaults to None.
@ -129,19 +127,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4. multiples of 1/fps. Defaults to 1e-4.
download_data (bool, optional): Flag to download actual data. Defaults to True. download_videos (bool, optional): Flag to download the videos. Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav. a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
""" """
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
self.root = root if root is not None else LEROBOT_HOME / repo_id self.root = root if root is not None else LEROBOT_HOME / repo_id
self.split = split
self.image_transforms = image_transforms self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.episodes = episodes self.episodes = episodes
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.download_data = download_data self.download_videos = download_videos
self.video_backend = video_backend if video_backend is not None else "pyav" self.video_backend = video_backend if video_backend is not None else "pyav"
self.delta_indices = None self.delta_indices = None
@ -152,13 +149,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.stats = load_stats(repo_id, self._version, self.root) self.stats = load_stats(repo_id, self._version, self.root)
self.tasks = load_tasks(repo_id, self._version, self.root) self.tasks = load_tasks(repo_id, self._version, self.root)
if not self.download_data:
# TODO(aliberts): Add actual support for this
# maybe use local_files_only=True or HF_HUB_OFFLINE=True
# see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
self.hf_dataset, self.episode_data_index = None, None
return
# Load actual data # Load actual data
self.download_episodes() self.download_episodes()
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes) self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
@ -192,12 +182,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(rcadene, aliberts): implement faster transfer # TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None files = None
ignore_patterns = None if self.download_videos else "videos/"
if self.episodes is not None: if self.episodes is not None:
files = [ files = [
self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes) self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
for ep_idx in self.episodes for ep_idx in self.episodes
] ]
if len(self.video_keys) > 0: if len(self.video_keys) > 0 and self.download_videos:
video_files = [ video_files = [
self.videos_path.format(video_key=vid_key, episode_index=ep_idx) self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
for vid_key in self.video_keys for vid_key in self.video_keys
@ -211,6 +202,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
revision=self._version, revision=self._version,
local_dir=self.root, local_dir=self.root,
allow_patterns=files, allow_patterns=files,
ignore_patterns=ignore_patterns,
) )
@property @property
@ -371,7 +363,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {**video_frames, **item} item = {**video_frames, **item}
if self.image_transforms is not None: if self.image_transforms is not None:
for cam in self.camera_keys: image_keys = self.camera_keys if self.download_videos else self.image_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam]) item[cam] = self.image_transforms(item[cam])
return item return item
@ -380,7 +373,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return ( return (
f"{self.__class__.__name__}(\n" f"{self.__class__.__name__}(\n"
f" Repository ID: '{self.repo_id}',\n" f" Repository ID: '{self.repo_id}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n" f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n" f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"