diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index eb76f78d..35e9c762 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 import os
+from itertools import accumulate
 from pathlib import Path
 from typing import Callable
 
@@ -24,27 +25,27 @@ import torch.utils
 
 from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
-    load_episode_data_index,
+    download_episodes,
+    get_hub_safe_version,
     load_hf_dataset,
     load_info,
     load_previous_and_future_frames,
     load_stats,
-    load_videos,
-    reset_episode_index,
+    load_tasks,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
 
 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
-CODEBASE_VERSION = "v1.6"
-DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+CODEBASE_VERSION = "v2.0"
+LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
 
 
 class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        root: Path | None = DATA_DIR,
+        root: Path | None = None,
+        episodes: list[int] | None = None,
         split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
@@ -52,24 +53,64 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ):
         super().__init__()
         self.repo_id = repo_id
-        self.root = root
+        self.root = root if root is not None else LEROBOT_HOME / repo_id
         self.split = split
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
-        # load data from hub or locally when root is provided
-        # TODO(rcadene, aliberts): implement faster transfer
-        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
-        self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
-        if split == "train":
-            self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
-        else:
-            self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
-            self.hf_dataset = reset_episode_index(self.hf_dataset)
-        self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
-        self.info = load_info(repo_id, CODEBASE_VERSION, root)
-        if self.video:
-            self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
-            self.video_backend = video_backend if video_backend is not None else "pyav"
+        self.episodes = episodes
+        self.video_backend = video_backend if video_backend is not None else "pyav"
+
+        # Load metadata
+        self.root.mkdir(exist_ok=True, parents=True)
+        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
+        self.info = load_info(repo_id, self._version, self.root)
+        self.stats = load_stats(repo_id, self._version, self.root)
+        self.tasks = load_tasks(repo_id, self._version, self.root)
+
+        # Load actual data
+        download_episodes(
+            repo_id,
+            self._version,
+            self.root,
+            self.data_path,
+            self.video_keys,
+            self.total_episodes,
+            self.episodes,
+            self.videos_path,
+        )
+        self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
+        self.episode_data_index = self.get_episode_data_index()
+
+        # TODO(aliberts):
+        # - [ ] Update __getitem__
+        # - [ ] Add self.consolidate() for:
+        #     - [ ] Sanity checks (episodes num, shapes, files, etc.)
+        #     - [ ] Update episode_index (arg update=True)
+        #     - [ ] Update info.json (arg update=True)
+
+        # TODO(aliberts): remove (deprecated)
+        # if split == "train":
+        #     self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
+        # else:
+        #     self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
+        #     self.hf_dataset = reset_episode_index(self.hf_dataset)
+        # if self.video:
+        #     self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
+
+    @property
+    def data_path(self) -> str:
+        """Formattable string for the parquet files."""
+        return self.info["data_path"]
+
+    @property
+    def videos_path(self) -> str | None:
+        """Formattable string for the video files."""
+        return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
+
+    @property
+    def episode_dicts(self) -> list[dict]:
+        """List of dictionaries containing information for each episode, indexed by episode_index."""
+        return self.info["episodes"]
 
     @property
     def fps(self) -> int:
@@ -77,24 +118,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return self.info["fps"]
 
     @property
-    def video(self) -> bool:
-        """Returns True if this dataset loads video frames from mp4 files.
-        Returns False if it only loads images from png files.
-        """
-        return self.info.get("video", False)
+    def keys(self) -> list[str]:
+        """Keys to access non-image data (state, actions, etc.)."""
+        return self.info["keys"]
 
     @property
-    def features(self) -> datasets.Features:
-        return self.hf_dataset.features
+    def image_keys(self) -> list[str]:
+        """Keys to access visual modalities stored as images."""
+        return self.info["image_keys"]
+
+    @property
+    def video_keys(self) -> list[str]:
+        """Keys to access visual modalities stored as videos."""
+        return self.info["video_keys"]
 
     @property
     def camera_keys(self) -> list[str]:
-        """Keys to access image and video stream from cameras."""
-        keys = []
-        for key, feats in self.hf_dataset.features.items():
-            if isinstance(feats, (datasets.Image, VideoFrame)):
-                keys.append(key)
-        return keys
+        """Keys to access image and video streams from cameras."""
+        return self.image_keys + self.video_keys
 
     @property
     def video_frame_keys(self) -> list[str]:
@@ -117,8 +158,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     @property
     def num_episodes(self) -> int:
-        """Number of episodes."""
-        return len(self.hf_dataset.unique("episode_index"))
+        """Number of episodes selected."""
+        return len(self.episodes) if self.episodes is not None else self.total_episodes
+
+    @property
+    def total_episodes(self) -> int:
+        """Total number of episodes available."""
+        return self.info["total_episodes"]
 
     @property
     def tolerance_s(self) -> float:
@@ -129,6 +175,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # 1e-4 to account for possible numerical error
         return 1 / self.fps - 1e-4
 
+    @property
+    def shapes(self) -> dict:
+        """Shapes for the different features."""
+        return self.info.get("shapes")
+
+    def get_episode_data_index(self) -> dict[str, torch.Tensor]:
+        episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
+        if self.episodes is not None:
+            episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}
+
+        cumulative_lengths = list(accumulate(episode_lengths.values()))
+        return {
+            "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
+            "to": torch.LongTensor(cumulative_lengths),
+        }
+
     def __len__(self):
         return self.num_samples
 
@@ -147,7 +209,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if self.video:
             item = load_from_videos(
                 item,
-                self.video_frame_keys,
+                self.video_keys,
                 self.videos_dir,
                 self.tolerance_s,
                 self.video_backend,
@@ -225,7 +287,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_ids: list[str],
-        root: Path | None = DATA_DIR,
+        root: Path | None = LEROBOT_HOME,
         split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index d6aef15f..fd76ccd1 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-import re
 import warnings
 from functools import cache
 from pathlib import Path
@@ -22,10 +21,9 @@ from typing import Dict
 
 import datasets
 import torch
-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
 from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
 from PIL import Image as PILImage
-from safetensors.torch import load_file
 from torchvision import transforms
 
 DATASET_CARD_TEMPLATE = """
@@ -96,7 +94,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
 
 
 @cache
-def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
+def get_hub_safe_version(repo_id: str, version: str) -> str:
+    num_version = float(version.strip("v"))
+    if num_version < 2:
+        raise ValueError(
+            f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
+            format with v2.0 that is not backward compatible. Please use our conversion script
+            first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
+        )
     api = HfApi()
     dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
     branches = [b.name for b in dataset_info.branches]
@@ -116,56 +121,27 @@ def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
     return version
 
 
-def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
+def load_hf_dataset(
+    local_dir: Path,
+    data_path: str,
+    total_episodes: int,
+    episodes: list[int] | None = None,
+    split: str = "train",
+) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
-    if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
-        # TODO(rcadene): clean this which enables getting a subset of dataset
-        if split != "train":
-            if "%" in split:
-                raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
-            match_from = re.search(r"train\[(\d+):\]", split)
-            match_to = re.search(r"train\[:(\d+)\]", split)
-            if match_from:
-                from_frame_index = int(match_from.group(1))
-                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
-            elif match_to:
-                to_frame_index = int(match_to.group(1))
-                hf_dataset = hf_dataset.select(range(to_frame_index))
-            else:
-                raise ValueError(
-                    f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
-                )
+    if episodes is None:
+        path = str(local_dir / "data")
+        hf_dataset = load_dataset("parquet", data_dir=path, split=split)
     else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
+        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
+        files = [str(local_dir / fpath) for fpath in files]
+        hf_dataset = load_dataset("parquet", data_files=files, split=split)
     hf_dataset.set_transform(hf_transform_to_torch)
     return hf_dataset
 
 
-def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
-    """episode_data_index contains the range of indices for each episode
-
-    Example:
-    ```python
-    from_id = episode_data_index["from"][episode_id].item()
-    to_id = episode_data_index["to"][episode_id].item()
-    episode_frames = [dataset[i] for i in range(from_id, to_id)]
-    ```
-    """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(
-            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
-        )
-
-    return load_file(path)
-
-
-def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
+def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
     """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
 
     Example:
@@ -173,47 +149,84 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
     normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
     ```
     """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(
-            repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
-        )
+    fpath = hf_hub_download(
+        repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        stats = json.load(f)
 
-    stats = load_file(path)
+    stats = flatten_dict(stats)
+    stats = {key: torch.tensor(value) for key, value in stats.items()}
     return unflatten_dict(stats)
 
 
-def load_info(repo_id, version, root) -> dict:
-    """info contains useful information regarding the dataset that are not stored elsewhere
+def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
+    """info contains structural information about the dataset. It should be the reference and
+    act as the 'source of truth' for what's inside the dataset.
 
     Example:
     ```python
     print("frame per second used to collect the video", info["fps"])
     ```
     """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "info.json"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
-
-    with open(path) as f:
-        info = json.load(f)
-    return info
+    fpath = hf_hub_download(
+        repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        return json.load(f)
 
 
-def load_videos(repo_id, version, root) -> Path:
-    if root is not None:
-        path = Path(root) / repo_id / "videos"
-    else:
-        # TODO(rcadene): we download the whole repo here. see if we can avoid this
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
-        path = Path(repo_dir) / "videos"
+def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
+    """tasks contains all the tasks of the dataset, indexed by their task_index.
 
-    return path
+    Example:
+    ```json
+    {
+        "0": "Pick the Lego block and drop it in the box on the right."
+    }
+    ```
+    """
+    fpath = hf_hub_download(
+        repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        return json.load(f)
+
+
+def download_episodes(
+    repo_id: str,
+    version: str,
+    local_dir: Path,
+    data_path: str,
+    video_keys: list,
+    total_episodes: int,
+    episodes: list[int] | None = None,
+    videos_path: str | None = None,
+) -> None:
+    """Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
+    will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
+    dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
+    in 'local_dir', they won't be downloaded again.
+
+    Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
+    snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
+    """
+    # TODO(rcadene, aliberts): implement faster transfer
+    # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
+    files = None
+    if episodes is not None:
+        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
+        if len(video_keys) > 0:
+            video_files = [
+                videos_path.format(video_key=vid_key, episode_index=ep_idx)
+                for vid_key in video_keys
+                for ep_idx in episodes
+            ]
+            files += video_files
+
+    snapshot_download(
+        repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
+    )
 
 
 def load_previous_and_future_frames(
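
For reference, a minimal usage sketch of the loading API this patch introduces, assuming a repo already converted to the v2.0 format; the repo_id "lerobot/pusht" and the chosen episode indices are only illustrative:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Download and load only two episodes; metadata, parquet and video files are
# cached under LEROBOT_HOME (default: ~/.cache/huggingface/lerobot/<repo_id>).
dataset = LeRobotDataset("lerobot/pusht", episodes=[0, 5])

print(dataset.total_episodes)  # total episodes available, from meta/info.json
print(dataset.num_episodes)    # 2, the number of selected episodes
print(dataset.camera_keys)     # image_keys + video_keys from the metadata

# Frame index range of the first selected episode, as computed by
# get_episode_data_index() from the per-episode lengths in the metadata.
ep_from = dataset.episode_data_index["from"][0].item()
ep_to = dataset.episode_data_index["to"][0].item()
```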