Add padding keys and download_data option

Simon Alibert 2024-10-11 17:38:47 +02:00
parent 7f680886b0
commit 3ea53124e0
1 changed file with 37 additions and 7 deletions

@@ -52,6 +52,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
download_data: bool = True,
video_backend: str | None = None,
):
"""LeRobotDataset encapsulates 3 main things:
@@ -128,6 +129,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
timestamp is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
download_data (bool, optional): Flag to download the actual data (episodes and videos). If False,
    only the metadata is loaded and `hf_dataset` is left as None. Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
@@ -139,6 +141,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.download_data = download_data
self.video_backend = video_backend if video_backend is not None else "pyav"
self.delta_indices = None
@@ -149,6 +152,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.stats = load_stats(repo_id, self._version, self.root)
self.tasks = load_tasks(repo_id, self._version, self.root)
if not self.download_data:
# TODO(aliberts): Add actual support for this
# maybe use local_files_only=True or HF_HUB_OFFLINE=True
# see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
self.hf_dataset, self.episode_data_index = None, None
return
# Load actual data
self.download_episodes()
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
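
Because the early return above leaves `hf_dataset` and `episode_data_index` as `None`, a dataset created with `download_data=False` can still serve metadata but cannot be indexed for frames. A small sketch of how calling code might guard on that (attribute names taken from this diff, the repo_id is illustrative):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # import path assumed

ds = LeRobotDataset("lerobot/pusht", download_data=False)

# Metadata loaded before the early return is still available.
print(ds.stats.keys())  # per-feature statistics
print(ds.tasks)         # task descriptions

# Anything that needs frames must check that the actual data was downloaded.
if ds.hf_dataset is None:
    raise RuntimeError("Dataset was created with download_data=False; no frames to index.")
```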
@@ -243,6 +253,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
return self.image_keys + self.video_keys
@property
def names(self) -> dict[str, list[str]]:
"""Names of the various dimensions of vector modalities."""
return self.info["names"]
@property
def num_samples(self) -> int:
"""Number of samples/frames."""
@@ -275,21 +290,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Number of samples/frames for given episode."""
return self.info["episodes"][episode_index]["length"]
def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]:
# Pad values outside of current episode range
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
return {
query_indices = {
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
)
for key, delta_idx in self.delta_indices.items()
}
return query_indices, padding
def _get_query_timestamps(
self, query_indices: dict[str, list[int]], current_ts: float
self,
current_ts: float,
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self.video_keys:
if key in query_indices:
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
@@ -320,6 +343,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
for key, val in padding.items():
item[key] = torch.BoolTensor(val)
return item
def __len__(self):
return self.num_samples
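
To make the new padding behavior concrete, here is a standalone toy reproduction of the index clamping and `*_is_pad` mask computed by `_get_query_indices` (episode bounds, current index, and delta indices are made up for illustration):

```python
import torch

# Toy episode covering global frame indices [10, 15), i.e. frames 10..14.
ep_start, ep_end = 10, 15
idx = 10                                     # current frame: first frame of the episode
delta_indices = {"action": [-2, -1, 0, 1]}   # two past steps, the current one, one future step

query_indices = {
    key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in deltas]
    for key, deltas in delta_indices.items()
}
padding = {
    f"{key}_is_pad": torch.BoolTensor(
        [(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in deltas]
    )
    for key, deltas in delta_indices.items()
}

print(query_indices)  # {'action': [10, 10, 10, 11]}  out-of-range queries clamped into the episode
print(padding)        # {'action_is_pad': tensor([ True,  True, False, False])}
```

The clamping keeps every query inside the episode so the subsequent `hf_dataset.select` call cannot pull frames from a neighboring episode, while the boolean mask records which of those clamped frames are padding rather than real context.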
@@ -327,16 +355,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
query_indices = None
if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
query_indices = self._get_query_indices(idx, current_ep_idx)
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
item[key] = val
if len(self.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(query_indices, current_ts)
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
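
Finally, a hedged sketch of how a consumer might use the padding keys returned by `__getitem__` (repo_id, feature name, and timestamps are illustrative; the `*_is_pad` entries come from this commit, and the queried feature is assumed to come back stacked along the time dimension):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # import path assumed

dataset = LeRobotDataset(
    "lerobot/pusht",
    delta_timestamps={"action": [0.0, 0.1, 0.2, 0.3]},  # current step + three future action steps
)

item = dataset[0]
actions = item["action"]            # action chunk, clamped at episode boundaries
action_pad = item["action_is_pad"]  # BoolTensor marking the steps that fell outside the episode

valid = ~action_pad                 # True where the timestep is real data
usable_actions = actions[valid]     # drop padded steps before, e.g., computing a loss
```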