Add suggestions from code review
parent 3ea53124e0
commit 8bd406e607
@@ -48,11 +48,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         repo_id: str,
         root: Path | None = None,
         episodes: list[int] | None = None,
-        split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
         tolerance_s: float = 1e-4,
-        download_data: bool = True,
+        download_videos: bool = True,
         video_backend: str | None = None,
     ):
         """LeRobotDataset encapsulates 3 main things:
@@ -64,7 +63,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
         - (optional) videos from which frames are loaded to be synchronous with data from parquet files.
 
-        3 use modes are available for this class, depending on 3 different use cases:
+        3 modes are available for this class, depending on 3 different use cases:
 
         1. Your dataset already exists on the Hugging Face Hub at the address
            https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder:
@@ -119,7 +118,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
                '~/.cache/huggingface/lerobot'.
            episodes (list[int] | None, optional): If specified, this will only load episodes specified by
                their episode_index in this list. Defaults to None.
-            split (str, optional): _description_. Defaults to "train".
            image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
                torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
                from videos or images). Defaults to None.
@@ -129,19 +127,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
                timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
                decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
                multiples of 1/fps. Defaults to 1e-4.
-            download_data (bool, optional): Flag to download actual data. Defaults to True.
+            download_videos (bool, optional): Flag to download the videos. Defaults to True.
            video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
                a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
         """
         super().__init__()
         self.repo_id = repo_id
         self.root = root if root is not None else LEROBOT_HOME / repo_id
-        self.split = split
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
-        self.download_data = download_data
+        self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.delta_indices = None
 
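Note: `delta_timestamps` and `tolerance_s` are unchanged by this commit; the sketch below only illustrates the expected shape of `delta_timestamps` (the keys and offsets are illustrative, not taken from this diff).

    # Illustrative only: map each data key to relative timestamps (in seconds)
    # to load around the current frame. Offsets should be multiples of 1/fps,
    # which is checked against `tolerance_s`.
    delta_timestamps = {
        "observation.image": [-0.2, -0.1, 0.0],  # two past frames + current frame
        "action": [t / 10 for t in range(15)],   # current + 14 future actions
    }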
@@ -152,13 +149,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.stats = load_stats(repo_id, self._version, self.root)
         self.tasks = load_tasks(repo_id, self._version, self.root)
 
-        if not self.download_data:
-            # TODO(aliberts): Add actual support for this
-            # maybe use local_files_only=True or HF_HUB_OFFLINE=True
-            # see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
-            self.hf_dataset, self.episode_data_index = None, None
-            return
-
         # Load actual data
         self.download_episodes()
         self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
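For context, a minimal usage sketch of the constructor after this change, where `download_videos` replaces the former `download_data` flag (the repo id and import path below are assumptions, not part of this diff):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # import path assumed

    # Skip downloading the mp4 files; only the parquet data is fetched.
    dataset = LeRobotDataset("lerobot/pusht", download_videos=False)

    # Default behaviour: videos are downloaded and decoded with the pyav backend.
    dataset = LeRobotDataset("lerobot/pusht", episodes=[0, 1, 2])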
@@ -192,12 +182,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(rcadene, aliberts): implement faster transfer
         # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
         files = None
+        ignore_patterns = None if self.download_videos else "videos/"
         if self.episodes is not None:
             files = [
                 self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
                 for ep_idx in self.episodes
             ]
-            if len(self.video_keys) > 0:
+            if len(self.video_keys) > 0 and self.download_videos:
                 video_files = [
                     self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
                     for vid_key in self.video_keys
@@ -211,6 +202,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             revision=self._version,
             local_dir=self.root,
             allow_patterns=files,
+            ignore_patterns=ignore_patterns,
         )
 
     @property
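The new `ignore_patterns` argument is forwarded to `huggingface_hub.snapshot_download`; a standalone sketch of the intended effect (the repo id is illustrative):

    from huggingface_hub import snapshot_download

    # With download_videos=False, everything under videos/ is skipped;
    # allow_patterns (when episodes are selected) further restricts the files pulled.
    snapshot_download(
        "lerobot/pusht",
        repo_type="dataset",
        ignore_patterns="videos/",  # None would download the videos as well
    )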
@@ -371,7 +363,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item = {**video_frames, **item}
 
         if self.image_transforms is not None:
-            for cam in self.camera_keys:
+            image_keys = self.camera_keys if self.download_videos else self.image_keys
+            for cam in image_keys:
                 item[cam] = self.image_transforms(item[cam])
 
         return item
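A short sketch of passing `image_transforms` (a torchvision v2 transform, per the docstring above); the transform choice is illustrative:

    from torchvision.transforms import v2

    transforms = v2.Compose([
        v2.ColorJitter(brightness=0.2, contrast=0.2),
        v2.RandomAdjustSharpness(sharpness_factor=2, p=0.3),
    ])
    # With download_videos=False, the transforms are applied to image keys only.
    dataset = LeRobotDataset("lerobot/pusht", image_transforms=transforms, download_videos=False)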
@@ -380,7 +373,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return (
             f"{self.__class__.__name__}(\n"
             f"  Repository ID: '{self.repo_id}',\n"
-            f"  Split: '{self.split}',\n"
             f"  Number of Samples: {self.num_samples},\n"
             f"  Number of Episodes: {self.num_episodes},\n"
             f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"