Add padding keys and download_data option
This commit is contained in:
parent
7f680886b0
commit
3ea53124e0
|
@ -52,6 +52,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
download_data: bool = True,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
"""LeRobotDataset encapsulates 3 main things:
|
||||
|
@ -128,6 +129,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
|
||||
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
||||
multiples of 1/fps. Defaults to 1e-4.
|
||||
download_data (bool, optional): Flag to download actual data. Defaults to True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
"""
|
||||
|
@ -139,6 +141,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.download_data = download_data
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.delta_indices = None
|
||||
|
||||
|
@ -149,6 +152,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.stats = load_stats(repo_id, self._version, self.root)
|
||||
self.tasks = load_tasks(repo_id, self._version, self.root)
|
||||
|
||||
if not self.download_data:
|
||||
# TODO(aliberts): Add actual support for this
|
||||
# maybe use local_files_only=True or HF_HUB_OFFLINE=True
|
||||
# see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
|
||||
self.hf_dataset, self.episode_data_index = None, None
|
||||
return
|
||||
|
||||
# Load actual data
|
||||
self.download_episodes()
|
||||
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
|
||||
|
@ -243,6 +253,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
|
||||
return self.image_keys + self.video_keys
|
||||
|
||||
@property
|
||||
def names(self) -> dict[list[str]]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
return self.info["names"]
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
"""Number of samples/frames."""
|
||||
|
@ -275,21 +290,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""Number of samples/frames for given episode."""
|
||||
return self.info["episodes"][episode_index]["length"]
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]:
|
||||
# Pad values outside of current episode range
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
return {
|
||||
query_indices = {
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
return query_indices, padding
|
||||
|
||||
def _get_query_timestamps(
|
||||
self, query_indices: dict[str, list[int]], current_ts: float
|
||||
self,
|
||||
current_ts: float,
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
for key in self.video_keys:
|
||||
if key in query_indices:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
else:
|
||||
|
@ -320,6 +343,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
return item
|
||||
|
||||
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
|
||||
for key, val in padding.items():
|
||||
item[key] = torch.BoolTensor(val)
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
@ -327,16 +355,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
||||
query_indices = self._get_query_indices(idx, current_ep_idx)
|
||||
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
|
||||
if len(self.video_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(query_indices, current_ts)
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
|
||||
|
|
Loading…
Reference in New Issue