diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index b283a185..61d27287 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -52,6 +52,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
         tolerance_s: float = 1e-4,
+        download_data: bool = True,
         video_backend: str | None = None,
     ):
         """LeRobotDataset encapsulates 3 main things:
@@ -128,6 +129,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
                 decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
                 multiples of 1/fps. Defaults to 1e-4.
+            download_data (bool, optional): Flag to download actual data. Defaults to True.
             video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
                 a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
         """
@@ -139,6 +141,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
+        self.download_data = download_data
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.delta_indices = None
 
@@ -149,6 +152,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         self.stats = load_stats(repo_id, self._version, self.root)
         self.tasks = load_tasks(repo_id, self._version, self.root)
+        if not self.download_data:
+            # TODO(aliberts): Add actual support for this
+            # maybe use local_files_only=True or HF_HUB_OFFLINE=True
+            # see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
+            self.hf_dataset, self.episode_data_index = None, None
+            return
+
         # Load actual data
         self.download_episodes()
         self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
@@ -243,6 +253,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Keys to access image and video streams from cameras (regardless of their storage method)."""
         return self.image_keys + self.video_keys
 
+    @property
+    def names(self) -> dict[list[str]]:
+        """Names of the various dimensions of vector modalities."""
+        return self.info["names"]
+
     @property
     def num_samples(self) -> int:
         """Number of samples/frames."""
@@ -275,21 +290,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Number of samples/frames for given episode."""
         return self.info["episodes"][episode_index]["length"]
 
-    def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]:
-        # Pad values outside of current episode range
+    def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
         ep_start = self.episode_data_index["from"][ep_idx]
         ep_end = self.episode_data_index["to"][ep_idx]
-        return {
+        query_indices = {
             key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
             for key, delta_idx in self.delta_indices.items()
         }
+        padding = {  # Pad values outside of current episode range
+            f"{key}_is_pad": torch.BoolTensor(
+                [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
+            )
+            for key, delta_idx in self.delta_indices.items()
+        }
+        return query_indices, padding
 
     def _get_query_timestamps(
-        self, query_indices: dict[str, list[int]], current_ts: float
+        self,
+        current_ts: float,
+        query_indices: dict[str, list[int]] | None = None,
     ) -> dict[str, list[float]]:
         query_timestamps = {}
         for key in self.video_keys:
-            if key in query_indices:
+            if query_indices is not None and key in query_indices:
                 timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
                 query_timestamps[key] = torch.stack(timestamps).tolist()
             else:
@@ -320,6 +343,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return item
 
+    def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
+        for key, val in padding.items():
+            item[key] = torch.BoolTensor(val)
+        return item
+
     def __len__(self):
         return self.num_samples
 
@@ -327,16 +355,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item = self.hf_dataset[idx]
         ep_idx = item["episode_index"].item()
 
+        query_indices = None
         if self.delta_indices is not None:
             current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
-            query_indices = self._get_query_indices(idx, current_ep_idx)
+            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
             query_result = self._query_hf_dataset(query_indices)
+            item = {**item, **padding}
            for key, val in query_result.items():
                 item[key] = val
 
         if len(self.video_keys) > 0:
             current_ts = item["timestamp"].item()
-            query_timestamps = self._get_query_timestamps(query_indices, current_ts)
+            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
             video_frames = self._query_videos(query_timestamps, ep_idx)
             item = {**video_frames, **item}
 
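
Usage note (not part of the diff): a minimal sketch of how the new `download_data` flag could be exercised, assuming an existing Hub dataset such as `lerobot/pusht`. As the TODO in the diff indicates, metadata-only mode currently just sets `hf_dataset` and `episode_data_index` to `None`, so only metadata accessors (for example the new `names` property) are usable.

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Metadata-only instantiation: episodes are not downloaded, hf_dataset stays None.
    ds_meta = LeRobotDataset("lerobot/pusht", download_data=False)
    print(ds_meta.names)  # dimension names of vector modalities, read from the dataset metadata
    # ds_meta[0] would fail here: hf_dataset is None until the data is actually downloaded.

    # Default behaviour is unchanged: data is downloaded and frames are indexable as before.
    ds = LeRobotDataset("lerobot/pusht")
    frame = ds[0]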
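
Also not part of the diff: a sketch of the padding semantics added to `__getitem__` when `delta_timestamps` is set. The key name and repo id below are illustrative; what the diff introduces is that every key queried through `delta_indices` gains a matching boolean `"<key>_is_pad"` entry, True for deltas that fell outside the current episode (those indices are clamped to the episode's first/last frame by `_get_query_indices`).

    delta_timestamps = {"action": [0.0, 0.1, 0.2]}  # in seconds; turned into frame-index offsets (delta_indices)
    ds = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

    item = ds[0]
    item["action"]         # stacked values, one row per requested delta
    item["action_is_pad"]  # torch.BoolTensor of shape (3,): True where the delta landed outside the episode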