From fbf2f2222a65cbfda352298eebef164b636b8a53 Mon Sep 17 00:00:00 2001
From: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Date: Wed, 19 Feb 2025 08:36:32 +0100
Subject: [PATCH] Remove `local_files_only` and use `codebase_version` instead
 of branches (#734)

---
 examples/port_datasets/pusht_zarr.py          |   2 +-
 .../common/datasets/backward_compatibility.py |  40 ++++++
 lerobot/common/datasets/factory.py            |   4 +-
 lerobot/common/datasets/lerobot_dataset.py    | 129 ++++++++++--------
 lerobot/common/datasets/utils.py              |  95 +++++--------
 .../datasets/v2/convert_dataset_v1_to_v2.py   |   4 +-
 .../v21/batch_convert_dataset_v20_to_v21.py   |  49 +++++++
 .../v21/convert_dataset_v20_to_v21.py         |  64 +++++----
 .../common/robot_devices/control_configs.py   |   6 -
 lerobot/configs/default.py                    |   2 +-
 lerobot/scripts/control_robot.py              |   6 +-
 lerobot/scripts/visualize_dataset.py          |   9 +-
 lerobot/scripts/visualize_dataset_html.py     |  13 +-
 lerobot/scripts/visualize_image_transforms.py |   2 +-
 pyproject.toml                                |   3 +-
 tests/fixtures/dataset_factories.py           |  14 +-
 tests/test_control_robot.py                   |   5 +-
 tests/test_datasets.py                        |   4 -
 18 files changed, 253 insertions(+), 198 deletions(-)
 create mode 100644 lerobot/common/datasets/backward_compatibility.py
 create mode 100644 lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py

diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py
index 2eaf1c1c..622fbd14 100644
--- a/examples/port_datasets/pusht_zarr.py
+++ b/examples/port_datasets/pusht_zarr.py
@@ -223,5 +223,5 @@ if __name__ == "__main__":
     main(raw_dir, repo_id=repo_id, mode=mode)
 
     # Uncomment if you want to load the local dataset and explore it
-    # dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
+    # dataset = LeRobotDataset(repo_id=repo_id)
     # breakpoint()

diff --git a/lerobot/common/datasets/backward_compatibility.py b/lerobot/common/datasets/backward_compatibility.py
new file mode 100644
index 00000000..aa814549
--- /dev/null
+++ b/lerobot/common/datasets/backward_compatibility.py
@@ -0,0 +1,40 @@
+import packaging.version
+
+V2_MESSAGE = """
+The dataset you requested ({repo_id}) is in {version} format.
+
+We introduced a new format in v2.0, which is not backward compatible with v1.x.
+Please use our conversion script. Modify the following command with your own task description:
+```
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
+    --repo-id {repo_id} \\
+    --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
+```
+
+A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
+peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
+cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
+target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
+sweatshirt.", ...
+
+If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+"""
+
+V21_MESSAGE = """
+The dataset you requested ({repo_id}) is in {version} format.
+While the current version of LeRobot is backward-compatible with it, your dataset still uses global
+stats instead of per-episode stats. Update your dataset stats to the new format using this command:
+```
+python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
+```
+
+If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+"""
+
+
+class BackwardCompatibilityError(Exception):
+    def __init__(self, repo_id: str, version: packaging.version.Version):
+        message = V2_MESSAGE.format(repo_id=repo_id, version=version)
+        super().__init__(message)
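For context, a minimal sketch of how this new exception surfaces to users. This assumes only what the file above defines; the `check_version_compatibility` call site that actually raises it appears later in this patch (in `lerobot/common/datasets/utils.py`), and the repo id and version here are placeholders:

```python
import packaging.version

from lerobot.common.datasets.backward_compatibility import BackwardCompatibilityError

try:
    # Simulate requesting a dataset whose codebase_version is still v1.6.
    raise BackwardCompatibilityError("lerobot/pusht", packaging.version.parse("v1.6"))
except BackwardCompatibilityError as err:
    # Prints V2_MESSAGE formatted with the repo id and version, pointing at the conversion script.
    print(err)
```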
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 95ba76b8..fb1fe6d6 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -83,15 +83,15 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
     )
 
     if isinstance(cfg.dataset.repo_id, str):
-        ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
+        ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, revision=cfg.dataset.revision)
         delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
         dataset = LeRobotDataset(
             cfg.dataset.repo_id,
             episodes=cfg.dataset.episodes,
             delta_timestamps=delta_timestamps,
             image_transforms=image_transforms,
+            revision=cfg.dataset.revision,
             video_backend=cfg.dataset.video_backend,
-            local_files_only=cfg.dataset.local_files_only,
         )
     else:
         raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index c7f0b2b3..dfdb3618 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -16,7 +16,6 @@
 import logging
 import os
 import shutil
-from functools import cached_property
 from pathlib import Path
 from typing import Callable
 
@@ -27,6 +26,8 @@
 import torch
 import torch.utils
 from datasets import load_dataset
 from huggingface_hub import HfApi, snapshot_download
+from huggingface_hub.constants import REPOCARD_NAME
+from packaging import version
 
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -41,14 +42,13 @@ from lerobot.common.datasets.utils import (
     check_frame_features,
     check_timestamps_sync,
     check_version_compatibility,
-    create_branch,
     create_empty_dataset_info,
     create_lerobot_dataset_card,
     get_delta_indices,
     get_episode_data_index,
     get_features_from_robot,
     get_hf_features_from_features,
-    get_hub_safe_version,
+    get_safe_revision,
     hf_transform_to_torch,
     load_episodes,
     load_episodes_stats,
@@ -79,30 +79,35 @@ class LeRobotDatasetMetadata:
         self,
         repo_id: str,
         root: str | Path | None = None,
-        local_files_only: bool = False,
+        revision: str | None = None,
+        force_cache_sync: bool = False,
     ):
         self.repo_id = repo_id
+        self.revision = revision if revision else CODEBASE_VERSION
         self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
-        self.local_files_only = local_files_only
 
-        # Load metadata
-        (self.root / "meta").mkdir(exist_ok=True, parents=True)
-        self.pull_from_repo(allow_patterns="meta/")
+        try:
+            if force_cache_sync:
+                raise FileNotFoundError
+            self.load_metadata()
+        except (FileNotFoundError, NotADirectoryError):
+            (self.root / "meta").mkdir(exist_ok=True, parents=True)
+            self.revision = get_safe_revision(self.repo_id, self.revision)
+            self.pull_from_repo(allow_patterns="meta/")
+            self.load_metadata()
+
+        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
+
+    def load_metadata(self):
         self.info = load_info(self.root)
-        self.stats = load_stats(self.root)
         self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)
-        try:
-            self.episodes_stats = load_episodes_stats(self.root)
-            self.stats = aggregate_stats(list(self.episodes_stats.values()))
-        except FileNotFoundError:
-            logging.warning(
-                f"""'episodes_stats.jsonl' not found. Using global dataset stats for each episode instead.
-                Convert your dataset stats to the new format using this command:
-                python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={self.repo_id} """
-            )
+        if version.parse(self._version) < version.parse("v2.1"):
             self.stats = load_stats(self.root)
             self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
+        else:
+            self.episodes_stats = load_episodes_stats(self.root)
+            self.stats = aggregate_stats(list(self.episodes_stats.values()))
 
     def pull_from_repo(
         self,
@@ -112,17 +117,12 @@ class LeRobotDatasetMetadata:
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self._hub_version,
+            revision=self.revision,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
-            local_files_only=self.local_files_only,
         )
 
-    @cached_property
-    def _hub_version(self) -> str | None:
-        return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
-
     @property
     def _version(self) -> str:
         """Codebase version used to create this dataset."""
@@ -342,7 +342,7 @@ class LeRobotDatasetMetadata:
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
         write_json(obj.info, obj.root / INFO_PATH)
-        obj.local_files_only = True
+        obj.revision = None
         return obj
 
 
@@ -355,8 +355,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
         tolerance_s: float = 1e-4,
+        revision: str | None = None,
+        force_cache_sync: bool = False,
         download_videos: bool = True,
-        local_files_only: bool = False,
         video_backend: str | None = None,
     ):
         """
@@ -366,7 +367,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             - On your local disk in the 'root' folder. This is typically the case when you recorded your
               dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
               with 'root' will load your dataset directly from disk. This can happen while you're offline (no
-              internet connection), in that case, use local_files_only=True.
+              internet connection).
 
             - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
               your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
@@ -448,11 +449,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
                 decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
                 multiples of 1/fps. Defaults to 1e-4.
+            revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
+                commit hash. Defaults to the current codebase version tag.
+            force_cache_sync (bool, optional): Flag to sync and refresh local files first. If False and files
+                are already present in the local cache, loading is faster, but the loaded files might not be
+                in sync with the version on the hub, especially if you specified 'revision'. Defaults to
+                False.
             download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                 video files are already present on local disk, they won't be downloaded again. Defaults to
                 True.
-            local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
-                will be made. Defaults to False.
             video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
                 a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
         """
@@ -463,9 +468,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
+        self.revision = revision if revision else CODEBASE_VERSION
         self.video_backend = video_backend if video_backend else "pyav"
         self.delta_indices = None
-        self.local_files_only = local_files_only
 
         # Unused attributes
         self.image_writer = None
@@ -474,17 +479,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.root.mkdir(exist_ok=True, parents=True)
 
         # Load metadata
-        self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
-        if self.episodes is not None and self.meta._version == CODEBASE_VERSION:
+        self.meta = LeRobotDatasetMetadata(
+            self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
+        )
+        if self.episodes is not None and version.parse(self.meta._version) >= version.parse("v2.1"):
             episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
             self.stats = aggregate_stats(episodes_stats)
 
-        # Check version
-        check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
-
         # Load actual data
-        self.download_episodes(download_videos)
-        self.hf_dataset = self.load_hf_dataset()
+        try:
+            if force_cache_sync:
+                raise FileNotFoundError
+            assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
+            self.hf_dataset = self.load_hf_dataset()
+        except (AssertionError, FileNotFoundError, NotADirectoryError):
+            self.revision = get_safe_revision(self.repo_id, self.revision)
+            self.download_episodes(download_videos)
+            self.hf_dataset = self.load_hf_dataset()
+
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
 
         # Check timestamps
@@ -501,7 +513,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def push_to_hub(
         self,
         branch: str | None = None,
-        create_card: bool = True,
         tags: list | None = None,
         license: str | None = "apache-2.0",
         push_videos: bool = True,
@@ -528,7 +539,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
             exist_ok=True,
         )
         if branch:
-            create_branch(repo_id=self.repo_id, branch=branch, repo_type="dataset")
+            hub_api.create_branch(
+                repo_id=self.repo_id,
+                branch=branch,
+                revision=self.revision,
+                repo_type="dataset",
+                exist_ok=True,
+            )
 
         hub_api.upload_folder(
             repo_id=self.repo_id,
@@ -538,15 +555,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
         )
-        if create_card:
+        if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
             card = create_lerobot_dataset_card(
                 tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
             )
             card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
 
-        if not branch:
-            create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
-
     def pull_from_repo(
         self,
         allow_patterns: list[str] | str | None = None,
@@ -555,11 +569,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self.meta._hub_version,
+            revision=self.revision,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
-            local_files_only=self.local_files_only,
         )
 
     def download_episodes(self, download_videos: bool = True) -> None:
@@ -573,17 +586,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
         files = None
         ignore_patterns = None if download_videos else "videos/"
         if self.episodes is not None:
-            files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
-            if len(self.meta.video_keys) > 0 and download_videos:
-                video_files = [
-                    str(self.meta.get_video_file_path(ep_idx, vid_key))
-                    for vid_key in self.meta.video_keys
-                    for ep_idx in self.episodes
-                ]
-                files += video_files
+            files = self.get_episodes_file_paths()
 
         self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
 
+    def get_episodes_file_paths(self) -> list[Path]:
+        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
+        fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
+        if len(self.meta.video_keys) > 0:
+            video_files = [
+                str(self.meta.get_video_file_path(ep_idx, vid_key))
+                for vid_key in self.meta.video_keys
+                for ep_idx in episodes
+            ]
+            fpaths += video_files
+
+        return fpaths
+
     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
         if self.episodes is None:
@@ -991,7 +1010,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
         obj.repo_id = obj.meta.repo_id
         obj.root = obj.meta.root
-        obj.local_files_only = obj.meta.local_files_only
+        obj.revision = None
         obj.tolerance_s = tolerance_s
         obj.image_writer = None
 
@@ -1033,7 +1052,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         delta_timestamps: dict[list[float]] | None = None,
         tolerances_s: dict | None = None,
         download_videos: bool = True,
-        local_files_only: bool = False,
         video_backend: str | None = None,
     ):
         super().__init__()
@@ -1051,7 +1069,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 delta_timestamps=delta_timestamps,
                 tolerance_s=self.tolerances_s[repo_id],
                 download_videos=download_videos,
-                local_files_only=local_files_only,
                 video_backend=video_backend,
             )
             for repo_id in repo_ids
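With `local_files_only` removed, loading is now controlled by `revision` and `force_cache_sync`. A usage sketch of the API introduced above, assuming the repo has been tagged by the conversion flow later in this patch (`lerobot/pusht` is a placeholder):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Default: revision resolves to the CODEBASE_VERSION tag, and the local
# cache is reused whenever all requested episode files are already present.
dataset = LeRobotDataset("lerobot/pusht")

# Pin an explicit tag/branch/commit and force a re-sync with the hub.
dataset = LeRobotDataset("lerobot/pusht", revision="v2.1", force_cache_sync=True)
```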
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 8b734042..c9b0c345 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -13,10 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 import importlib.resources
 import json
 import logging
-import textwrap
 from collections.abc import Iterator
 from itertools import accumulate
 from pathlib import Path
@@ -31,9 +31,11 @@
 import pyarrow.compute as pc
 import torch
 from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, DatasetCardData, HfApi
+from packaging import version
 from PIL import Image as PILImage
 from torchvision import transforms
 
+from lerobot.common.datasets.backward_compatibility import V21_MESSAGE, BackwardCompatibilityError
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import is_valid_numpy_dtype_string
 from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
@@ -200,7 +202,7 @@ def write_task(task_index: int, task: dict, local_dir: Path):
     append_jsonlines(task_dict, local_dir / TASKS_PATH)
 
 
-def load_tasks(local_dir: Path) -> dict:
+def load_tasks(local_dir: Path) -> tuple[dict, dict]:
     tasks = load_jsonlines(local_dir / TASKS_PATH)
     tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
     task_to_task_index = {task: task_index for task_index, task in tasks.items()}
@@ -231,7 +233,9 @@ def load_episodes_stats(local_dir: Path) -> dict:
     }
 
 
-def backward_compatible_episodes_stats(stats, episodes: list[int]) -> dict[str, dict[str, np.ndarray]]:
+def backward_compatible_episodes_stats(
+    stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
+) -> dict[str, dict[str, np.ndarray]]:
     return {ep_idx: stats for ep_idx in episodes}
 
 
@@ -265,73 +269,38 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict
 
 
-def _get_major_minor(version: str) -> tuple[int]:
-    split = version.strip("v").split(".")
-    return int(split[0]), int(split[1])
-
-
-class BackwardCompatibilityError(Exception):
-    def __init__(self, repo_id, version):
-        message = textwrap.dedent(f"""
-            BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
-
-            We introduced a new format since v2.0 which is not backward compatible with v1.x.
-            Please, use our conversion script. Modify the following command with your own task description:
-            ```
-            python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
-                --repo-id {repo_id} \\
-                --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
-            ```
-
-            A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
-            "Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
-            "Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
-            "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
-
-            If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
-            or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
-            """)
-        super().__init__(message)
-
-
 def check_version_compatibility(
     repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
 ) -> None:
-    current_major, _ = _get_major_minor(current_version)
-    major_to_check, _ = _get_major_minor(version_to_check)
-    if major_to_check < current_major and enforce_breaking_major:
-        raise BackwardCompatibilityError(repo_id, version_to_check)
-    elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
-        logging.warning(
-            f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
-            codebase. The current codebase version is {current_version}. You should be fine since
-            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
-            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
-        )
+    v_check = version.parse(version_to_check)
+    v_current = version.parse(current_version)
+    if v_check.major < v_current.major and enforce_breaking_major:
+        raise BackwardCompatibilityError(repo_id, v_check)
+    elif v_check.minor < v_current.minor:
+        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=version_to_check))
 
 
-def get_hub_safe_version(repo_id: str, version: str) -> str:
+def get_repo_versions(repo_id: str) -> list[version.Version]:
+    """Returns available valid versions (branches and tags) on the given repo."""
     api = HfApi()
-    dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
-    branches = [b.name for b in dataset_info.branches]
-    if version not in branches:
-        num_version = float(version.strip("v"))
-        hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
-        if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
-            raise BackwardCompatibilityError(repo_id, version)
+    repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
+    repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
+    repo_versions = []
+    for ref in repo_refs:
+        with contextlib.suppress(version.InvalidVersion):
+            repo_versions.append(version.parse(ref))
 
-        logging.warning(
-            f"""You are trying to load a dataset from {repo_id} created with a previous version of the
-            codebase. The following versions are available: {branches}.
-            The requested version ('{version}') is not found. You should be fine since
-            backward compatibility is maintained.
-            If you encounter a problem, contact LeRobot maintainers on
-            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
-        )
-        if "main" not in branches:
-            raise ValueError(f"Version 'main' not found on {repo_id}")
-        return "main"
-    else:
-        return version
+    return repo_versions
+
+
+def get_safe_revision(repo_id: str, revision: str) -> str:
+    """Returns the given revision if it exists on the repo, otherwise the latest available version."""
+    api = HfApi()
+    if api.revision_exists(repo_id, revision, repo_type="dataset"):
+        return revision
+
+    hub_versions = get_repo_versions(repo_id)
+    return f"v{max(hub_versions)}"
 
 
 def get_hf_features_from_features(features: dict) -> datasets.Features:
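A rough model of the fallback behind `get_safe_revision` above, for illustration only (not the shipped code): refs that don't parse as versions, such as 'main', are skipped, and the newest parseable one wins. The ref list is made up; `list_repo_refs` and `revision_exists` are the actual `huggingface_hub` calls used in the patch.

```python
import contextlib

from packaging import version

refs = ["main", "v1.6", "v2.0", "v2.1"]  # e.g. branches + tags from HfApi.list_repo_refs

versions = []
for ref in refs:
    with contextlib.suppress(version.InvalidVersion):
        versions.append(version.parse(ref))  # "main" raises InvalidVersion and is skipped

print(f"v{max(versions)}")  # -> "v2.1", the revision get_safe_revision would fall back to
```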
diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
index 62ca9932..943e94f0 100644
--- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
+++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
     create_branch,
     create_lerobot_dataset_card,
     flatten_dict,
-    get_hub_safe_version,
+    get_safe_revision,
     load_json,
     unflatten_dict,
     write_json,
@@ -443,7 +443,7 @@ def convert_dataset(
     test_branch: str | None = None,
     **card_kwargs,
 ):
-    v1 = get_hub_safe_version(repo_id, V16)
+    v1 = get_safe_revision(repo_id, V16)
     v1x_dir = local_dir / V16 / repo_id
     v20_dir = local_dir / V20 / repo_id
     v1x_dir.mkdir(parents=True, exist_ok=True)

diff --git a/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py
new file mode 100644
index 00000000..624827bd
--- /dev/null
+++ b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
+"""
+
+import traceback
+from pathlib import Path
+
+from lerobot import available_datasets
+from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset
+
+LOCAL_DIR = Path("data/")
+
+
+def batch_convert():
+    status = {}
+    logfile = LOCAL_DIR / "conversion_log_v21.txt"
+    for num, repo_id in enumerate(available_datasets):
+        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+        print("---------------------------------------------------------")
+        try:
+            convert_dataset(repo_id)
+            status = f"{repo_id}: success."
+            with open(logfile, "a") as file:
+                file.write(status + "\n")
+        except Exception:
+            status = f"{repo_id}: failed\n    {traceback.format_exc()}"
+            with open(logfile, "a") as file:
+                file.write(status + "\n")
+            continue
+
+
+if __name__ == "__main__":
+    batch_convert()

diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
index 0c5d2688..d52a0a10 100644
--- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
+++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
@@ -1,10 +1,12 @@
 """
 This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
-2.1. It performs the following:
+2.1. It will:
 
 - Generates per-episodes stats and writes them in `episodes_stats.jsonl`
+- Check consistency between these new stats and the old ones
 - Removes the deprecated `stats.json` (by default)
 - Updates codebase_version in `info.json`
+- Push this new version to the hub on the 'main' branch and tag it with "v2.1"
 
 Usage:
 
@@ -14,9 +16,9 @@ python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
 ```
 """
 
-# TODO(rcadene, aliberts): ensure this script works for any other changes for the final v2.1
-
 import argparse
+import logging
 
 from huggingface_hub import HfApi
 
@@ -24,14 +26,27 @@ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDat
 from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
 from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
 
+V20 = "v2.0"
+V21 = "v2.1"
 
-def main(
+
+class SuppressWarnings:
+    def __enter__(self):
+        self.previous_level = logging.getLogger().getEffectiveLevel()
+        logging.getLogger().setLevel(logging.ERROR)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        logging.getLogger().setLevel(self.previous_level)
+
+
+def convert_dataset(
     repo_id: str,
-    test_branch: str | None = None,
-    delete_old_stats: bool = False,
+    branch: str | None = None,
     num_workers: int = 4,
 ):
-    dataset = LeRobotDataset(repo_id)
+    with SuppressWarnings():
+        dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
+
     if (dataset.root / EPISODES_STATS_PATH).is_file():
         raise FileExistsError("episodes_stats.jsonl already exists.")
 
@@ -42,18 +57,21 @@ def main(
     dataset.meta.info["codebase_version"] = CODEBASE_VERSION
     write_info(dataset.meta.info, dataset.root)
 
-    dataset.push_to_hub(branch=test_branch, create_card=False, allow_patterns="meta/")
+    dataset.push_to_hub(branch=branch, allow_patterns="meta/")
 
-    if delete_old_stats:
-        if (dataset.root / STATS_PATH).is_file:
-            (dataset.root / STATS_PATH).unlink()
-        hub_api = HfApi()
-        if hub_api.file_exists(
-            STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
-        ):
-            hub_api.delete_file(
-                STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
-            )
+    # delete old stats.json file
+    if (dataset.root / STATS_PATH).is_file():
+        (dataset.root / STATS_PATH).unlink()
+
+    hub_api = HfApi()
+    if hub_api.file_exists(
+        repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
+    ):
+        hub_api.delete_file(
+            path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
+        )
+
+    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
 
 
 if __name__ == "__main__":
@@ -65,16 +83,10 @@ if __name__ == "__main__":
         help="Repository identifier on Hugging Face: a community or a user name `/` the name of "
         "the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
     )
     parser.add_argument(
-        "--test-branch",
+        "--branch",
         type=str,
         default=None,
-        help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
-    )
-    parser.add_argument(
-        "--delete-old-stats",
-        type=bool,
-        default=False,
-        help="Delete the deprecated `stats.json`",
+        help="Repo branch to push your dataset to (defaults to the main branch)",
     )
     parser.add_argument(
         "--num-workers",
@@ -84,4 +96,4 @@ if __name__ == "__main__":
     )
 
     args = parser.parse_args()
-    main(**vars(args))
+    convert_dataset(**vars(args))

diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py
index c96a87f0..d7d03ac0 100644
--- a/lerobot/common/robot_devices/control_configs.py
+++ b/lerobot/common/robot_devices/control_configs.py
@@ -81,9 +81,6 @@ class RecordControlConfig(ControlConfig):
     play_sounds: bool = True
     # Resume recording on an existing dataset.
     resume: bool = False
-    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
-    # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
-    local_files_only: bool = False
 
     def __post_init__(self):
         # HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -128,9 +125,6 @@ class ReplayControlConfig(ControlConfig):
     fps: int | None = None
     # Use vocal synthesis to read events.
     play_sounds: bool = True
-    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
-    # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
-    local_files_only: bool = False
 
 
 @dataclass

diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py
index 5dd2f898..a5013431 100644
--- a/lerobot/configs/default.py
+++ b/lerobot/configs/default.py
@@ -31,7 +31,7 @@ class DatasetConfig:
     repo_id: str
     episodes: list[int] | None = None
     image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
-    local_files_only: bool = False
+    revision: str | None = None
     use_imagenet_stats: bool = True
     video_backend: str = "pyav"
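On the config side, a training run can now pin a dataset revision instead of toggling offline mode. A sketch with example values (with the project's draccus-style CLI parsing this presumably maps to `--dataset.revision=v2.1` on the command line):

```python
from lerobot.configs.default import DatasetConfig

# Omitting `revision` keeps the default (None), which the dataset classes
# resolve to the current CODEBASE_VERSION tag.
cfg = DatasetConfig(repo_id="lerobot/pusht", revision="v2.1")
```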
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 5f51c81b..dee2792d 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -85,7 +85,6 @@ python lerobot/scripts/control_robot.py record \
 This might require a sudo permission to allow your terminal to monitor keyboard events.
 
 **NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
-If the dataset you want to extend is not on the hub, you also need to add `--control.local_files_only=true`.
 
 - Train on this dataset with the ACT policy:
 ```bash
@@ -216,7 +215,6 @@ def record(
         dataset = LeRobotDataset(
             cfg.repo_id,
             root=cfg.root,
-            local_files_only=cfg.local_files_only,
         )
         if len(robot.cameras) > 0:
             dataset.start_image_writer(
@@ -318,9 +316,7 @@ def replay(
     # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
     # TODO(rcadene): Add option to record logs
-    dataset = LeRobotDataset(
-        cfg.repo_id, root=cfg.root, episodes=[cfg.episode], local_files_only=cfg.local_files_only
-    )
+    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
     actions = dataset.hf_dataset.select_columns("action")
 
     if not robot.is_connected:

diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 626b0bde..11feb1af 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -207,12 +207,6 @@ def main():
         required=True,
         help="Episode to visualize.",
     )
-    parser.add_argument(
-        "--local-files-only",
-        type=int,
-        default=0,
-        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
-    )
     parser.add_argument(
         "--root",
         type=Path,
@@ -275,10 +269,9 @@ def main():
     kwargs = vars(args)
     repo_id = kwargs.pop("repo_id")
     root = kwargs.pop("root")
-    local_files_only = kwargs.pop("local_files_only")
 
     logging.info("Loading dataset")
-    dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
+    dataset = LeRobotDataset(repo_id, root=root)
 
     visualize_dataset(dataset, **vars(args))

diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
index cc3f3930..a022c91e 100644
--- a/lerobot/scripts/visualize_dataset_html.py
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -384,12 +384,6 @@ def main():
         default=None,
         help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
     )
-    parser.add_argument(
-        "--local-files-only",
-        type=int,
-        default=0,
-        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
-    )
     parser.add_argument(
         "--root",
         type=Path,
@@ -445,15 +439,10 @@ def main():
     repo_id = kwargs.pop("repo_id")
     load_from_hf_hub = kwargs.pop("load_from_hf_hub")
     root = kwargs.pop("root")
-    local_files_only = kwargs.pop("local_files_only")
 
     dataset = None
     if repo_id:
-        dataset = (
-            LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
-            if not load_from_hf_hub
-            else get_dataset_info(repo_id)
-        )
+        dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)
 
     visualize_dataset_html(dataset, **vars(args))

diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py
index 727fe178..80935d32 100644
--- a/lerobot/scripts/visualize_image_transforms.py
+++ b/lerobot/scripts/visualize_image_transforms.py
@@ -109,7 +109,7 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR
     dataset = LeRobotDataset(
         repo_id=cfg.repo_id,
         episodes=cfg.episodes,
-        local_files_only=cfg.local_files_only,
+        revision=cfg.revision,
         video_backend=cfg.video_backend,
     )

diff --git a/pyproject.toml b/pyproject.toml
index 6e7e0575..4a3355e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,7 @@ dependencies = [
     "numba>=0.59.0",
     "omegaconf>=2.3.0",
     "opencv-python>=4.9.0",
+    "packaging>=24.2",
    "pyav>=12.0.5",
     "pymunk>=6.6.0",
     "rerun-sdk>=0.21.0",
@@ -54,7 +55,7 @@ dependencies = [
     "torch>=2.2.1",
     "torchvision>=0.21.0",
     "wandb>=0.16.3",
-    "zarr>=2.17.0"
+    "zarr>=2.17.0",
 ]
 
 [project.optional-dependencies]

diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index f57f945b..811e29b7 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -309,7 +309,6 @@ def lerobot_dataset_metadata_factory(
         episodes_stats: list[dict] | None = None,
         tasks: list[dict] | None = None,
         episodes: list[dict] | None = None,
-        local_files_only: bool = False,
     ) -> LeRobotDatasetMetadata:
         if not info:
             info = info_factory()
@@ -335,16 +334,16 @@ def lerobot_dataset_metadata_factory(
         )
         with (
             patch(
-                "lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
-            ) as mock_get_hub_safe_version_patch,
+                "lerobot.common.datasets.lerobot_dataset.get_safe_revision"
+            ) as mock_get_safe_revision_patch,
             patch(
                 "lerobot.common.datasets.lerobot_dataset.snapshot_download"
             ) as mock_snapshot_download_patch,
         ):
-            mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
+            mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
             mock_snapshot_download_patch.side_effect = mock_snapshot_download
 
-            return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
+            return LeRobotDatasetMetadata(repo_id=repo_id, root=root)
 
     return _create_lerobot_dataset_metadata
 
@@ -411,15 +410,18 @@ def lerobot_dataset_factory(
             episodes_stats=episodes_stats,
             tasks=tasks,
             episodes=episode_dicts,
-            local_files_only=kwargs.get("local_files_only", False),
         )
         with (
             patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
+            patch(
+                "lerobot.common.datasets.lerobot_dataset.get_safe_revision"
+            ) as mock_get_safe_revision_patch,
             patch(
                 "lerobot.common.datasets.lerobot_dataset.snapshot_download"
             ) as mock_snapshot_download_patch,
         ):
             mock_metadata_patch.return_value = mock_metadata
+            mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
             mock_snapshot_download_patch.side_effect = mock_snapshot_download
 
             return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py
index a4f538a6..12b68641 100644
--- a/tests/test_control_robot.py
+++ b/tests/test_control_robot.py
@@ -167,9 +167,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
     assert dataset.meta.total_episodes == 2
     assert len(dataset) == 2
 
-    replay_cfg = ReplayControlConfig(
-        episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True
-    )
+    replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
     replay(robot, replay_cfg)
 
     policy_cfg = ACTConfig()
@@ -266,7 +264,6 @@ def test_resume_record(tmp_path, request, robot_type, mock):
         video=False,
         display_cameras=False,
         play_sounds=False,
-        local_files_only=True,
         num_episodes=1,
     )

diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 72fa5b50..3e8b531d 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -77,10 +77,6 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     root_init = tmp_path / "init"
     dataset_init = lerobot_dataset_factory(root=root_init)
 
-    # Access the '_hub_version' cached_property in both instances to force its creation
-    _ = dataset_init.meta._hub_version
-    _ = dataset_create.meta._hub_version
-
     init_attr = set(vars(dataset_init).keys())
     create_attr = set(vars(dataset_create).keys())
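Putting the pieces together, the v2.0 to v2.1 migration this patch enables might look like the following sketch; the repo id is a placeholder, and the conversion only needs to run once per dataset:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset

# One-off migration: recompute per-episode stats, push meta/ to 'main',
# and tag the repo with the current CODEBASE_VERSION ("v2.1").
convert_dataset("lerobot/pusht")

# Afterwards, plain loads resolve the codebase version tag instead of a branch.
dataset = LeRobotDataset("lerobot/pusht")
print(dataset.meta._version)  # expected: "v2.1"
```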