Add suggestions from code review

This commit is contained in:
Simon Alibert 2024-10-11 18:52:11 +02:00
parent 3ea53124e0
commit 8bd406e607
1 changed files with 9 additions and 17 deletions

View File

@ -48,11 +48,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
repo_id: str,
root: Path | None = None,
episodes: list[int] | None = None,
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
download_data: bool = True,
download_videos: bool = True,
video_backend: str | None = None,
):
"""LeRobotDataset encapsulates 3 main things:
@ -64,7 +63,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
- (optional) videos from which frames are loaded to be synchronous with data from parquet files.
3 use modes are available for this class, depending on 3 different use cases:
3 modes are available for this class, depending on 3 different use cases:
1. Your dataset already exists on the Hugging Face Hub at the address
https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder:
@ -119,7 +118,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
'~/.cache/huggingface/lerobot'.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
split (str, optional): _description_. Defaults to "train".
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
from videos or images). Defaults to None.
@ -129,19 +127,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
download_data (bool, optional): Flag to download actual data. Defaults to True.
download_videos (bool, optional): Flag to download the videos. Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
super().__init__()
self.repo_id = repo_id
self.root = root if root is not None else LEROBOT_HOME / repo_id
self.split = split
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.download_data = download_data
self.download_videos = download_videos
self.video_backend = video_backend if video_backend is not None else "pyav"
self.delta_indices = None
@ -152,13 +149,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.stats = load_stats(repo_id, self._version, self.root)
self.tasks = load_tasks(repo_id, self._version, self.root)
if not self.download_data:
# TODO(aliberts): Add actual support for this
# maybe use local_files_only=True or HF_HUB_OFFLINE=True
# see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
self.hf_dataset, self.episode_data_index = None, None
return
# Load actual data
self.download_episodes()
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
@ -192,12 +182,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
ignore_patterns = None if self.download_videos else "videos/"
if self.episodes is not None:
files = [
self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
for ep_idx in self.episodes
]
if len(self.video_keys) > 0:
if len(self.video_keys) > 0 and self.download_videos:
video_files = [
self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
for vid_key in self.video_keys
@ -211,6 +202,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
revision=self._version,
local_dir=self.root,
allow_patterns=files,
ignore_patterns=ignore_patterns,
)
@property
@ -371,7 +363,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {**video_frames, **item}
if self.image_transforms is not None:
for cam in self.camera_keys:
image_keys = self.camera_keys if self.download_videos else self.image_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
return item
@ -380,7 +373,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return (
f"{self.__class__.__name__}(\n"
f" Repository ID: '{self.repo_id}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"