Add padding keys and download_data option
This commit is contained in:
parent
7f680886b0
commit
3ea53124e0
|
@ -52,6 +52,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
delta_timestamps: dict[list[float]] | None = None,
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
tolerance_s: float = 1e-4,
|
tolerance_s: float = 1e-4,
|
||||||
|
download_data: bool = True,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
):
|
):
|
||||||
"""LeRobotDataset encapsulates 3 main things:
|
"""LeRobotDataset encapsulates 3 main things:
|
||||||
|
@ -128,6 +129,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
|
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
|
||||||
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
||||||
multiples of 1/fps. Defaults to 1e-4.
|
multiples of 1/fps. Defaults to 1e-4.
|
||||||
|
download_data (bool, optional): Flag to download actual data. Defaults to True.
|
||||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||||
"""
|
"""
|
||||||
|
@ -139,6 +141,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
|
self.download_data = download_data
|
||||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||||
self.delta_indices = None
|
self.delta_indices = None
|
||||||
|
|
||||||
|
@ -149,6 +152,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.stats = load_stats(repo_id, self._version, self.root)
|
self.stats = load_stats(repo_id, self._version, self.root)
|
||||||
self.tasks = load_tasks(repo_id, self._version, self.root)
|
self.tasks = load_tasks(repo_id, self._version, self.root)
|
||||||
|
|
||||||
|
if not self.download_data:
|
||||||
|
# TODO(aliberts): Add actual support for this
|
||||||
|
# maybe use local_files_only=True or HF_HUB_OFFLINE=True
|
||||||
|
# see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
|
||||||
|
self.hf_dataset, self.episode_data_index = None, None
|
||||||
|
return
|
||||||
|
|
||||||
# Load actual data
|
# Load actual data
|
||||||
self.download_episodes()
|
self.download_episodes()
|
||||||
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
|
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
|
||||||
|
@ -243,6 +253,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
|
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
|
||||||
return self.image_keys + self.video_keys
|
return self.image_keys + self.video_keys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def names(self) -> dict[list[str]]:
|
||||||
|
"""Names of the various dimensions of vector modalities."""
|
||||||
|
return self.info["names"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_samples(self) -> int:
|
def num_samples(self) -> int:
|
||||||
"""Number of samples/frames."""
|
"""Number of samples/frames."""
|
||||||
|
@ -275,21 +290,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
"""Number of samples/frames for given episode."""
|
"""Number of samples/frames for given episode."""
|
||||||
return self.info["episodes"][episode_index]["length"]
|
return self.info["episodes"][episode_index]["length"]
|
||||||
|
|
||||||
def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]:
|
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||||
# Pad values outside of current episode range
|
|
||||||
ep_start = self.episode_data_index["from"][ep_idx]
|
ep_start = self.episode_data_index["from"][ep_idx]
|
||||||
ep_end = self.episode_data_index["to"][ep_idx]
|
ep_end = self.episode_data_index["to"][ep_idx]
|
||||||
return {
|
query_indices = {
|
||||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||||
for key, delta_idx in self.delta_indices.items()
|
for key, delta_idx in self.delta_indices.items()
|
||||||
}
|
}
|
||||||
|
padding = { # Pad values outside of current episode range
|
||||||
|
f"{key}_is_pad": torch.BoolTensor(
|
||||||
|
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||||
|
)
|
||||||
|
for key, delta_idx in self.delta_indices.items()
|
||||||
|
}
|
||||||
|
return query_indices, padding
|
||||||
|
|
||||||
def _get_query_timestamps(
|
def _get_query_timestamps(
|
||||||
self, query_indices: dict[str, list[int]], current_ts: float
|
self,
|
||||||
|
current_ts: float,
|
||||||
|
query_indices: dict[str, list[int]] | None = None,
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query_timestamps = {}
|
query_timestamps = {}
|
||||||
for key in self.video_keys:
|
for key in self.video_keys:
|
||||||
if key in query_indices:
|
if query_indices is not None and key in query_indices:
|
||||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||||
else:
|
else:
|
||||||
|
@ -320,6 +343,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
|
||||||
|
for key, val in padding.items():
|
||||||
|
item[key] = torch.BoolTensor(val)
|
||||||
|
return item
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_samples
|
return self.num_samples
|
||||||
|
|
||||||
|
@ -327,16 +355,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
item = self.hf_dataset[idx]
|
item = self.hf_dataset[idx]
|
||||||
ep_idx = item["episode_index"].item()
|
ep_idx = item["episode_index"].item()
|
||||||
|
|
||||||
|
query_indices = None
|
||||||
if self.delta_indices is not None:
|
if self.delta_indices is not None:
|
||||||
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
||||||
query_indices = self._get_query_indices(idx, current_ep_idx)
|
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
|
||||||
query_result = self._query_hf_dataset(query_indices)
|
query_result = self._query_hf_dataset(query_indices)
|
||||||
|
item = {**item, **padding}
|
||||||
for key, val in query_result.items():
|
for key, val in query_result.items():
|
||||||
item[key] = val
|
item[key] = val
|
||||||
|
|
||||||
if len(self.video_keys) > 0:
|
if len(self.video_keys) > 0:
|
||||||
current_ts = item["timestamp"].item()
|
current_ts = item["timestamp"].item()
|
||||||
query_timestamps = self._get_query_timestamps(query_indices, current_ts)
|
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||||
item = {**video_frames, **item}
|
item = {**video_frames, **item}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue