From b417cebc4e0c2dd8cc087d17684ed25902c91854 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 10 Oct 2024 21:32:14 +0200 Subject: [PATCH] Update LeRobotDataset.__get_item__ --- lerobot/common/datasets/lerobot_dataset.py | 191 +++++++++++++++------ lerobot/common/datasets/utils.py | 130 ++++++++++---- lerobot/common/datasets/video_utils.py | 39 +---- 3 files changed, 232 insertions(+), 128 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 35e9c762..b91eb75f 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -15,25 +15,27 @@ # limitations under the License. import logging import os -from itertools import accumulate from pathlib import Path from typing import Callable import datasets import torch import torch.utils +from huggingface_hub import snapshot_download from lerobot.common.datasets.compute_stats import aggregate_stats from lerobot.common.datasets.utils import ( - download_episodes, + check_delta_timestamps, + check_timestamps_sync, + get_delta_indices, + get_episode_data_index, get_hub_safe_version, load_hf_dataset, load_info, - load_previous_and_future_frames, load_stats, load_tasks, ) -from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos +from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md CODEBASE_VERSION = "v2.0" @@ -49,6 +51,7 @@ class LeRobotDataset(torch.utils.data.Dataset): split: str = "train", image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, + tolerance_s: float = 1e-4, video_backend: str | None = None, ): super().__init__() @@ -58,7 +61,9 @@ class LeRobotDataset(torch.utils.data.Dataset): self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.episodes = episodes + self.tolerance_s = tolerance_s self.video_backend = video_backend if video_backend is not None else "pyav" + self.delta_indices = None # Load metadata self.root.mkdir(exist_ok=True, parents=True) @@ -68,34 +73,60 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tasks = load_tasks(repo_id, self._version, self.root) # Load actual data - download_episodes( - repo_id, - self._version, - self.root, - self.data_path, - self.video_keys, - self.num_episodes, - self.episodes, - self.videos_path, - ) + self.download_episodes() self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes) - self.episode_data_index = self.get_episode_data_index() + self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts) + + # Check timestamps + check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) + + # Setup delta_indices + if self.delta_timestamps is not None: + check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) + self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) # TODO(aliberts): - # - [ ] Update __get_item__ + # - [X] Move delta_timestamp logic outside __get_item__ + # - [X] Update __get_item__ + # - [ ] Add self.add_frame() # - [ ] Add self.consolidate() for: + # - [X] Check timestamps sync # - [ ] Sanity checks (episodes num, shapes, files, etc.) # - [ ] Update episode_index (arg update=True) # - [ ] Update info.json (arg update=True) - # TODO(aliberts): remove (deprecated) - # if split == "train": - # self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list) - # else: - # self.episode_data_index = calculate_episode_data_index(self.hf_dataset) - # self.hf_dataset = reset_episode_index(self.hf_dataset) - # if self.video: - # self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root) + def download_episodes(self) -> None: + """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this + will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole + dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present + in 'local_dir', they won't be downloaded again. + + Note: Currently, if you're running this code offline but you already have the files in 'local_dir', + snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub. + """ + # TODO(rcadene, aliberts): implement faster transfer + # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads + files = None + if self.episodes is not None: + files = [ + self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes) + for ep_idx in self.episodes + ] + if len(self.video_keys) > 0: + video_files = [ + self.videos_path.format(video_key=vid_key, episode_index=ep_idx) + for vid_key in self.video_keys + for ep_idx in self.episodes + ] + files += video_files + + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self._version, + local_dir=self.root, + allow_patterns=files, + ) @property def data_path(self) -> str: @@ -134,17 +165,20 @@ class LeRobotDataset(torch.utils.data.Dataset): @property def camera_keys(self) -> list[str]: - """Keys to access image and video streams from cameras.""" + """Keys to access image and video streams from cameras (regardless of their storage method).""" return self.image_keys + self.video_keys @property def video_frame_keys(self) -> list[str]: - """Keys to access video frames that requires to be decoded into images. + """ + DEPRECATED, USE 'video_keys' INSTEAD + Keys to access video frames that requires to be decoded into images. Note: It is empty if the dataset contains images only, or equal to `self.cameras` if the dataset contains videos only, or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. """ + # TODO(aliberts): remove video_frame_keys = [] for key, feats in self.hf_dataset.features.items(): if isinstance(feats, VideoFrame): @@ -166,54 +200,97 @@ class LeRobotDataset(torch.utils.data.Dataset): """Total number of episodes available.""" return self.info["total_episodes"] - @property - def tolerance_s(self) -> float: - """Tolerance in seconds used to discard loaded frames when their timestamps - are not close enough from the requested frames. It is only used when `delta_timestamps` - is provided or when loading video frames from mp4 files. - """ - # 1e-4 to account for possible numerical error - return 1 / self.fps - 1e-4 + # @property + # def tolerance_s(self) -> float: + # """Tolerance in seconds used to discard loaded frames when their timestamps + # are not close enough from the requested frames. It is used at the init of the dataset to make sure + # that each timestamps is separated to the next by 1/fps +/- tolerance. It is only used when + # `delta_timestamps` is provided or when loading video frames from mp4 files. + # """ + # # 1e-4 to account for possible numerical error + # return 1e-4 @property def shapes(self) -> dict: """Shapes for the different features.""" self.info.get("shapes") - def get_episode_data_index(self) -> dict[str, torch.Tensor]: - episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)} + def current_episode_index(self, idx: int) -> int: + episode_index = self.hf_dataset["episode_index"][idx] if self.episodes is not None: - episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes} + # get episode_index from selected episodes + episode_index = self.episodes.index(episode_index) - cumulative_lenghts = list(accumulate(episode_lengths.values())) + return episode_index + + def episode_length(self, episode_index) -> int: + """Number of samples/frames for given episode.""" + return self.info["episodes"][episode_index]["length"] + + def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]: + # Pad values outside of current episode range + ep_start = self.episode_data_index["from"][ep_idx] + ep_end = self.episode_data_index["to"][ep_idx] return { - "from": torch.LongTensor([0] + cumulative_lenghts[:-1]), - "to": torch.LongTensor(cumulative_lenghts), + key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx] + for key, delta_idx in self.delta_indices.items() } + def _get_query_timestamps( + self, query_indices: dict[str, list[int]], current_ts: float + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self.video_keys: + if key in query_indices: + timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps + + def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: + return { + key: torch.stack(self.hf_dataset.select(q_idx)[key]) + for key, q_idx in query_indices.items() + if key not in self.video_keys + } + + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict: + """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function + in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a + Segmentation Fault. This probably happens because a memory reference to the video loader is created in + the main process and a subprocess fails to access it. + """ + item = {} + for vid_key, query_ts in query_timestamps.items(): + video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx) + frames = decode_video_frames_torchvision( + video_path, query_ts, self.tolerance_s, self.video_backend + ) + item[vid_key] = frames + + return item + def __len__(self): return self.num_samples - def __getitem__(self, idx): + def __getitem__(self, idx) -> dict: item = self.hf_dataset[idx] + ep_idx = item["episode_index"].item() - if self.delta_timestamps is not None: - item = load_previous_and_future_frames( - item, - self.hf_dataset, - self.episode_data_index, - self.delta_timestamps, - self.tolerance_s, - ) + if self.delta_indices is not None: + current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx + query_indices = self._get_query_indices(idx, current_ep_idx) + query_result = self._query_hf_dataset(query_indices) + for key, val in query_result.items(): + item[key] = val - if self.video: - item = load_from_videos( - item, - self.video_keys, - self.videos_dir, - self.tolerance_s, - self.video_backend, - ) + if len(self.video_keys) > 0: + current_ts = item["timestamp"].item() + query_timestamps = self._get_query_timestamps(query_indices, current_ts) + video_frames = self._query_videos(query_timestamps, ep_idx) + item = {**video_frames, **item} if self.image_transforms is not None: for cam in self.camera_keys: diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index fd76ccd1..9b70d4f6 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -16,13 +16,15 @@ import json import warnings from functools import cache +from itertools import accumulate from pathlib import Path +from pprint import pformat from typing import Dict import datasets import torch from datasets import load_dataset -from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download +from huggingface_hub import DatasetCard, HfApi, hf_hub_download from PIL import Image as PILImage from torchvision import transforms @@ -193,40 +195,102 @@ def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict: return json.load(f) -def download_episodes( - repo_id: str, - version: str, - local_dir: Path, - data_path: str, - video_keys: list, - total_episodes: int, - episodes: list[int] | None = None, - videos_path: str | None = None, -) -> None: - """Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this - will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole - dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present - in 'local_dir', they won't be downloaded again. - - Note: Currently, if you're running this code offline but you already have the files in 'local_dir', - snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub. - """ - # TODO(rcadene, aliberts): implement faster transfer - # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads - files = None +def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]: + episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)} if episodes is not None: - files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes] - if len(video_keys) > 0: - video_files = [ - videos_path.format(video_key=vid_key, episode_index=ep_idx) - for vid_key in video_keys - for ep_idx in episodes - ] - files += video_files + episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} - snapshot_download( - repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files - ) + cumulative_lenghts = list(accumulate(episode_lengths.values())) + return { + "from": torch.LongTensor([0] + cumulative_lenghts[:-1]), + "to": torch.LongTensor(cumulative_lenghts), + } + + +def check_timestamps_sync( + hf_dataset: datasets.Dataset, + episode_data_index: dict[str, torch.Tensor], + fps: int, + tolerance_s: float, + raise_value_error: bool = True, +) -> bool: + """ + This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to + account for possible numerical error. + """ + timestamps = torch.stack(hf_dataset["timestamp"]) + # timestamps[2] += tolerance_s # TODO delete + # timestamps[-2] += tolerance_s/2 # TODO delete + diffs = torch.diff(timestamps) + within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s + + # We mask differences between the timestamp at the end of an episode + # and the one the start of the next episode since these are expected + # to be outside tolerance. + mask = torch.ones(len(diffs), dtype=torch.bool) + ignored_diffs = episode_data_index["to"][:-1] - 1 + mask[ignored_diffs] = False + filtered_within_tolerance = within_tolerance[mask] + + if not torch.all(filtered_within_tolerance): + # Track original indices before masking + original_indices = torch.arange(len(diffs)) + filtered_indices = original_indices[mask] + outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze() + outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices] + episode_indices = torch.stack(hf_dataset["episode_index"]) + + outside_tolerances = [] + for idx in outside_tolerance_indices: + entry = { + "timestamps": [timestamps[idx], timestamps[idx + 1]], + "diff": diffs[idx], + "episode_index": episode_indices[idx].item(), + } + outside_tolerances.append(entry) + + if raise_value_error: + raise ValueError( + f"""One or several timestamps unexpectedly violate the tolerance inside episode range. + This might be due to synchronization issues with timestamps during data collection. + \n{pformat(outside_tolerances)}""" + ) + return False + + return True + + +def check_delta_timestamps( + delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True +) -> bool: + outside_tolerance = {} + for key, delta_ts in delta_timestamps.items(): + abs_delta_ts = torch.abs(torch.tensor(delta_ts)) + within_tolerance = (abs_delta_ts % (1 / fps)) <= tolerance_s + if not torch.all(within_tolerance): + outside_tolerance[key] = torch.tensor(delta_ts)[~within_tolerance] + + if len(outside_tolerance) > 0: + if raise_value_error: + raise ValueError( + f""" + The following delta_timestamps are found outside of tolerance range. + Please make sure they are multiples of 1/{fps} +/- tolerance and adjust + their values accordingly. + \n{pformat(outside_tolerance)} + """ + ) + return False + + return True + + +def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: + delta_indices = {} + for key, delta_ts in delta_timestamps.items(): + delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist() + + return delta_indices def load_previous_and_future_frames( diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 4d4ac6b0..6a606415 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -27,45 +27,8 @@ import torchvision from datasets.features.features import register_feature -def load_from_videos( - item: dict[str, torch.Tensor], - video_frame_keys: list[str], - videos_dir: Path, - tolerance_s: float, - backend: str = "pyav", -): - """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function - in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. - This probably happens because a memory reference to the video loader is created in the main process and a - subprocess fails to access it. - """ - # since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4") - data_dir = videos_dir.parent - - for key in video_frame_keys: - if isinstance(item[key], list): - # load multiple frames at once (expected when delta_timestamps is not None) - timestamps = [frame["timestamp"] for frame in item[key]] - paths = [frame["path"] for frame in item[key]] - if len(set(paths)) > 1: - raise NotImplementedError("All video paths are expected to be the same for now.") - video_path = data_dir / paths[0] - - frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) - item[key] = frames - else: - # load one frame - timestamps = [item[key]["timestamp"]] - video_path = data_dir / item[key]["path"] - - frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) - item[key] = frames[0] - - return item - - def decode_video_frames_torchvision( - video_path: str, + video_path: Path | str, timestamps: list[float], tolerance_s: float, backend: str = "pyav",