Add suggestions from code review
This commit is contained in:
parent
3ea53124e0
commit
8bd406e607
|
@ -48,11 +48,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
split: str = "train",
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
download_data: bool = True,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
"""LeRobotDataset encapsulates 3 main things:
|
||||
|
@ -64,7 +63,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
||||
- (optional) videos from which frames are loaded so that they are synchronized with data from parquet files.
|
||||
|
||||
3 use modes are available for this class, depending on 3 different use cases:
|
||||
3 modes are available for this class, depending on 3 different use cases:
|
||||
|
||||
1. Your dataset already exists on the Hugging Face Hub at the address
|
||||
https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder:
|
||||
|
@ -119,7 +118,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
'~/.cache/huggingface/lerobot'.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
split (str, optional): _description_. Defaults to "train".
|
||||
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
||||
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
|
||||
from videos or images). Defaults to None.
|
||||
|
@ -129,19 +127,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
timestamps is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
|
||||
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
||||
multiples of 1/fps. Defaults to 1e-4.
|
||||
download_data (bool, optional): Flag to download actual data. Defaults to True.
|
||||
download_videos (bool, optional): Flag to download the videos. Defaults to True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
self.split = split
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.download_data = download_data
|
||||
self.download_videos = download_videos
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.delta_indices = None
|
||||
|
||||
|
@ -152,13 +149,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.stats = load_stats(repo_id, self._version, self.root)
|
||||
self.tasks = load_tasks(repo_id, self._version, self.root)
|
||||
|
||||
if not self.download_data:
|
||||
# TODO(aliberts): Add actual support for this
|
||||
# maybe use local_files_only=True or HF_HUB_OFFLINE=True
|
||||
# see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
|
||||
self.hf_dataset, self.episode_data_index = None, None
|
||||
return
|
||||
|
||||
# Load actual data
|
||||
self.download_episodes()
|
||||
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
|
||||
|
@ -192,12 +182,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
files = None
|
||||
ignore_patterns = None if self.download_videos else "videos/"
|
||||
if self.episodes is not None:
|
||||
files = [
|
||||
self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
if len(self.video_keys) > 0:
|
||||
if len(self.video_keys) > 0 and self.download_videos:
|
||||
video_files = [
|
||||
self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
|
||||
for vid_key in self.video_keys
|
||||
|
@ -211,6 +202,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
revision=self._version,
|
||||
local_dir=self.root,
|
||||
allow_patterns=files,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
@property
|
||||
|
@ -371,7 +363,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
item = {**video_frames, **item}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
for cam in self.camera_keys:
|
||||
image_keys = self.camera_keys if self.download_videos else self.image_keys
|
||||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
return item
|
||||
|
@ -380,7 +373,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Split: '{self.split}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
|
|
Loading…
Reference in New Issue