Add padding keys and download_data option

This commit is contained in:
Simon Alibert 2024-10-11 17:38:47 +02:00
parent 7f680886b0
commit 3ea53124e0
1 changed files with 37 additions and 7 deletions

View File

@ -52,6 +52,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_transforms: Callable | None = None, image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4, tolerance_s: float = 1e-4,
download_data: bool = True,
video_backend: str | None = None, video_backend: str | None = None,
): ):
"""LeRobotDataset encapsulates 3 main things: """LeRobotDataset encapsulates 3 main things:
@ -128,6 +129,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4. multiples of 1/fps. Defaults to 1e-4.
download_data (bool, optional): Flag to download actual data. Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav. a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
""" """
@ -139,6 +141,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.episodes = episodes self.episodes = episodes
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.download_data = download_data
self.video_backend = video_backend if video_backend is not None else "pyav" self.video_backend = video_backend if video_backend is not None else "pyav"
self.delta_indices = None self.delta_indices = None
@ -149,6 +152,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.stats = load_stats(repo_id, self._version, self.root) self.stats = load_stats(repo_id, self._version, self.root)
self.tasks = load_tasks(repo_id, self._version, self.root) self.tasks = load_tasks(repo_id, self._version, self.root)
if not self.download_data:
# TODO(aliberts): Add actual support for this
# maybe use local_files_only=True or HF_HUB_OFFLINE=True
# see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
self.hf_dataset, self.episode_data_index = None, None
return
# Load actual data # Load actual data
self.download_episodes() self.download_episodes()
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes) self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
@ -243,6 +253,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Keys to access image and video streams from cameras (regardless of their storage method).""" """Keys to access image and video streams from cameras (regardless of their storage method)."""
return self.image_keys + self.video_keys return self.image_keys + self.video_keys
@property
def names(self) -> dict[list[str]]:
"""Names of the various dimensions of vector modalities."""
return self.info["names"]
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:
"""Number of samples/frames.""" """Number of samples/frames."""
@ -275,21 +290,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Number of samples/frames for given episode.""" """Number of samples/frames for given episode."""
return self.info["episodes"][episode_index]["length"] return self.info["episodes"][episode_index]["length"]
def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]: def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
# Pad values outside of current episode range
ep_start = self.episode_data_index["from"][ep_idx] ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx] ep_end = self.episode_data_index["to"][ep_idx]
return { query_indices = {
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx] key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items() for key, delta_idx in self.delta_indices.items()
} }
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
)
for key, delta_idx in self.delta_indices.items()
}
return query_indices, padding
def _get_query_timestamps( def _get_query_timestamps(
self, query_indices: dict[str, list[int]], current_ts: float self,
current_ts: float,
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]: ) -> dict[str, list[float]]:
query_timestamps = {} query_timestamps = {}
for key in self.video_keys: for key in self.video_keys:
if key in query_indices: if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist() query_timestamps[key] = torch.stack(timestamps).tolist()
else: else:
@ -320,6 +343,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item return item
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
for key, val in padding.items():
item[key] = torch.BoolTensor(val)
return item
def __len__(self): def __len__(self):
return self.num_samples return self.num_samples
@ -327,16 +355,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = self.hf_dataset[idx] item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item() ep_idx = item["episode_index"].item()
query_indices = None
if self.delta_indices is not None: if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
query_indices = self._get_query_indices(idx, current_ep_idx) query_indices, padding = self._get_query_indices(idx, current_ep_idx)
query_result = self._query_hf_dataset(query_indices) query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items(): for key, val in query_result.items():
item[key] = val item[key] = val
if len(self.video_keys) > 0: if len(self.video_keys) > 0:
current_ts = item["timestamp"].item() current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(query_indices, current_ts) query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx) video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item} item = {**video_frames, **item}