Remove `local_files_only` and use `codebase_version` instead of branches (#734)

Simon Alibert 2025-02-19 08:36:32 +01:00 committed by GitHub
parent 624eaf1175
commit fbf2f2222a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 253 additions and 198 deletions
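In short: loading no longer toggles `local_files_only`; instead you pin a Git `revision` (branch, tag, or commit hash), which defaults to the current `CODEBASE_VERSION` tag. A minimal sketch of the change at a call site (`lerobot/pusht` is only an example repo id):

```
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Old API, removed by this commit:
# dataset = LeRobotDataset("lerobot/pusht", local_files_only=True)

# New API: pin a branch, tag, or commit hash; omitting `revision`
# falls back to the codebase version tag (e.g. "v2.1").
dataset = LeRobotDataset("lerobot/pusht", revision="v2.1")
```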

View File

@ -223,5 +223,5 @@ if __name__ == "__main__":
main(raw_dir, repo_id=repo_id, mode=mode)
# Uncomment if you want to load the local dataset and explore it
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
# dataset = LeRobotDataset(repo_id=repo_id)
# breakpoint()

View File

@ -0,0 +1,40 @@
import packaging.version
V2_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format with v2.0, which is not backward compatible with v1.x.
Please use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
V21_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
While the current version of LeRobot is backward-compatible with it, your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
```
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
class BackwardCompatibilityError(Exception):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)
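For reference, a hedged sketch of how these templates are consumed: `check_version_compatibility` (further down in this diff) logs `V21_MESSAGE` as a warning, while the exception above formats `V2_MESSAGE` when raised:

```
import packaging.version

from lerobot.common.datasets.backward_compatibility import BackwardCompatibilityError

try:
    # A v1.x dataset triggers the hard error with conversion instructions.
    raise BackwardCompatibilityError("lerobot/pusht", packaging.version.parse("v1.6"))
except BackwardCompatibilityError as err:
    print(err)  # V2_MESSAGE, filled in with the repo id and version
```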

View File

@ -83,15 +83,15 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
)
if isinstance(cfg.dataset.repo_id, str):
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, revision=cfg.dataset.revision)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
local_files_only=cfg.dataset.local_files_only,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")

View File

@ -16,7 +16,6 @@
import logging
import os
import shutil
from functools import cached_property
from pathlib import Path
from typing import Callable
@ -27,6 +26,8 @@ import torch
import torch.utils
from datasets import load_dataset
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from packaging import version
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@ -41,14 +42,13 @@ from lerobot.common.datasets.utils import (
check_frame_features,
check_timestamps_sync,
check_version_compatibility,
create_branch,
create_empty_dataset_info,
create_lerobot_dataset_card,
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hf_features_from_features,
get_hub_safe_version,
get_safe_revision,
hf_transform_to_torch,
load_episodes,
load_episodes_stats,
@ -79,30 +79,35 @@ class LeRobotDatasetMetadata:
self,
repo_id: str,
root: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
force_cache_sync: bool = False,
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
self.local_files_only = local_files_only
# Load metadata
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
try:
if force_cache_sync:
raise FileNotFoundError
self.load_metadata()
except (FileNotFoundError, NotADirectoryError):
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.revision = get_safe_revision(self.repo_id, self.revision)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
def load_metadata(self):
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)
try:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
except FileNotFoundError:
logging.warning(
f"""'episodes_stats.jsonl' not found. Using global dataset stats for each episode instead.
Convert your dataset stats to the new format using this command:
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={self.repo_id} """
)
if version.parse(self._version) < version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
def pull_from_repo(
self,
@ -112,17 +117,12 @@ class LeRobotDatasetMetadata:
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self._hub_version,
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
local_files_only=self.local_files_only,
)
@cached_property
def _hub_version(self) -> str | None:
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
@property
def _version(self) -> str:
"""Codebase version used to create this dataset."""
@ -342,7 +342,7 @@ class LeRobotDatasetMetadata:
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
obj.local_files_only = True
obj.revision = None
return obj
@ -355,8 +355,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
"""
@ -366,7 +367,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
- On your local disk in the 'root' folder. This is typically the case when you recorded your
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
internet connection), in that case, use local_files_only=True.
internet connection).
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
@ -448,11 +449,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. Defaults to current codebase version tag.
force_cache_sync (bool, optional): Flag to sync and refresh local files first. If False and files
are already present in the local cache, loading is faster, but the files might not be in
sync with the version on the hub, especially if you specified 'revision'. Defaults to
False.
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
will be made. Defaults to False.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
@ -463,9 +468,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "pyav"
self.delta_indices = None
self.local_files_only = local_files_only
# Unused attributes
self.image_writer = None
@ -474,17 +479,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
if self.episodes is not None and self.meta._version == CODEBASE_VERSION:
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and version.parse(self.meta._version) >= version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
self.stats = aggregate_stats(episodes_stats)
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
# Load actual data
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
try:
if force_cache_sync:
raise FileNotFoundError
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_revision(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# Check timestamps
@ -501,7 +513,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
def push_to_hub(
self,
branch: str | None = None,
create_card: bool = True,
tags: list | None = None,
license: str | None = "apache-2.0",
push_videos: bool = True,
@ -528,7 +539,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
exist_ok=True,
)
if branch:
create_branch(repo_id=self.repo_id, branch=branch, repo_type="dataset")
hub_api.create_branch(
repo_id=self.repo_id,
branch=branch,
revision=self.revision,
repo_type="dataset",
exist_ok=True,
)
hub_api.upload_folder(
repo_id=self.repo_id,
@ -538,15 +555,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
if create_card:
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
if not branch:
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
@ -555,11 +569,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.meta._hub_version,
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
local_files_only=self.local_files_only,
)
def download_episodes(self, download_videos: bool = True) -> None:
@ -573,17 +586,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
files = None
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
if len(self.meta.video_keys) > 0 and download_videos:
video_files = [
str(self.meta.get_video_file_path(ep_idx, vid_key))
for vid_key in self.meta.video_keys
for ep_idx in self.episodes
]
files += video_files
files = self.get_episodes_file_paths()
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[str]:
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self.meta.video_keys) > 0:
video_files = [
str(self.meta.get_video_file_path(ep_idx, vid_key))
for vid_key in self.meta.video_keys
for ep_idx in episodes
]
fpaths += video_files
return fpaths
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if self.episodes is None:
@ -991,7 +1010,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only
obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None
@ -1033,7 +1052,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[str, list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
super().__init__()
@ -1051,7 +1069,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
local_files_only=local_files_only,
video_backend=video_backend,
)
for repo_id in repo_ids

View File

@ -13,10 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib.resources
import json
import logging
import textwrap
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
@ -31,9 +31,11 @@ import pyarrow.compute as pc
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from packaging import version
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.common.datasets.backward_compatibility import V21_MESSAGE, BackwardCompatibilityError
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
@ -200,7 +202,7 @@ def write_task(task_index: int, task: dict, local_dir: Path):
append_jsonlines(task_dict, local_dir / TASKS_PATH)
def load_tasks(local_dir: Path) -> dict:
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
@ -231,7 +233,9 @@ def load_episodes_stats(local_dir: Path) -> dict:
}
def backward_compatible_episodes_stats(stats, episodes: list[int]) -> dict[str, dict[str, np.ndarray]]:
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
return {ep_idx: stats for ep_idx in episodes}
@ -265,73 +269,38 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict
def _get_major_minor(version: str) -> tuple[int]:
split = version.strip("v").split(".")
return int(split[0]), int(split[1])
class BackwardCompatibilityError(Exception):
def __init__(self, repo_id, version):
message = textwrap.dedent(f"""
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
""")
super().__init__(message)
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
if major_to_check < current_major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, version_to_check)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
logging.warning(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
codebase. The current codebase version is {current_version}. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
)
v_check = version.parse(version_to_check)
v_current = version.parse(current_version)
if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor:
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=version_to_check))
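The rewritten check delegates version ordering to `packaging.version`; roughly, under the new rules:

```
from packaging import version

v_check, v_current = version.parse("v2.0"), version.parse("v2.1")
assert v_check.major == v_current.major  # same major: no BackwardCompatibilityError
assert v_check.minor < v_current.minor   # older minor: V21_MESSAGE warning is logged
```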
def get_hub_safe_version(repo_id: str, version: str) -> str:
def get_repo_versions(repo_id: str) -> list[version.Version]:
"""Returns available valid versions (branches and tags) on given repo."""
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
num_version = float(version.strip("v"))
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
raise BackwardCompatibilityError(repo_id, version)
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
repo_versions = []
for ref in repo_refs:
with contextlib.suppress(version.InvalidVersion):
repo_versions.append(version.parse(ref))
logging.warning(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
The requested version ('{version}') is not found. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
)
if "main" not in branches:
raise ValueError(f"Version 'main' not found on {repo_id}")
return "main"
else:
return version
return repo_versions
def get_safe_revision(repo_id: str, revision: str) -> str:
"""Returns the version if available on repo, otherwise return the latest available."""
api = HfApi()
if api.revision_exists(repo_id, revision, repo_type="dataset"):
return revision
hub_versions = get_repo_versions(repo_id)
return f"v{max(hub_versions)}"
def get_hf_features_from_features(features: dict) -> datasets.Features:

View File

@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
create_branch,
create_lerobot_dataset_card,
flatten_dict,
get_hub_safe_version,
get_safe_revision,
load_json,
unflatten_dict,
write_json,
@ -443,7 +443,7 @@ def convert_dataset(
test_branch: str | None = None,
**card_kwargs,
):
v1 = get_hub_safe_version(repo_id, V16)
v1 = get_safe_revision(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
"""
import traceback
from pathlib import Path
from lerobot import available_datasets
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset
LOCAL_DIR = Path("data/")
def batch_convert():
status = {}
logfile = LOCAL_DIR / "conversion_log_v21.txt"
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
convert_dataset(repo_id)
status = f"{repo_id}: success."
with open(logfile, "a") as file:
file.write(status + "\n")
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
continue
if __name__ == "__main__":
batch_convert()

View File

@ -1,10 +1,12 @@
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It performs the following:
2.1. It will:
- Generate per-episode stats and write them to `episodes_stats.jsonl`
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update `codebase_version` in `info.json`.
- Push this new version to the hub on the 'main' branch and tag it with "v2.1".
Usage:
@ -14,9 +16,9 @@ python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
```
"""
# TODO(rcadene, aliberts): ensure this script works for any other changes for the final v2.1
import argparse
import logging
from huggingface_hub import HfApi
@ -24,14 +26,27 @@ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDat
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
V20 = "v2.0"
V21 = "v2.1"
def main(
class SuppressWarnings:
def __enter__(self):
self.previous_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.ERROR)
def __exit__(self, exc_type, exc_val, exc_tb):
logging.getLogger().setLevel(self.previous_level)
def convert_dataset(
repo_id: str,
test_branch: str | None = None,
delete_old_stats: bool = False,
branch: str | None = None,
num_workers: int = 4,
):
dataset = LeRobotDataset(repo_id)
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
if (dataset.root / EPISODES_STATS_PATH).is_file():
raise FileExistsError("episodes_stats.jsonl already exists.")
@ -42,18 +57,21 @@ def main(
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=test_branch, create_card=False, allow_patterns="meta/")
dataset.push_to_hub(branch=branch, allow_patterns="meta/")
if delete_old_stats:
if (dataset.root / STATS_PATH).is_file:
(dataset.root / STATS_PATH).unlink()
hub_api = HfApi()
if hub_api.file_exists(
STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
):
hub_api.delete_file(
STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
)
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file():
(dataset.root / STATS_PATH).unlink()
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
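Since dataset versions now live as tags rather than branches, a post-conversion sanity check could look like this (repo id illustrative):

```
from huggingface_hub import HfApi

# The conversion should have tagged the repo with the new codebase version.
assert HfApi().revision_exists("lerobot/pusht", "v2.1", repo_type="dataset")
```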
if __name__ == "__main__":
@ -65,16 +83,10 @@ if __name__ == "__main__":
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--test-branch",
"--branch",
type=str,
default=None,
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
)
parser.add_argument(
"--delete-old-stats",
type=bool,
default=False,
help="Delete the deprecated `stats.json`",
help="Repo branch to push your dataset (defaults to the main branch)",
)
parser.add_argument(
"--num-workers",
@ -84,4 +96,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
main(**vars(args))
convert_dataset(**vars(args))

View File

@ -81,9 +81,6 @@ class RecordControlConfig(ControlConfig):
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
local_files_only: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@ -128,9 +125,6 @@ class ReplayControlConfig(ControlConfig):
fps: int | None = None
# Use vocal synthesis to read events.
play_sounds: bool = True
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
local_files_only: bool = False
@dataclass

View File

@ -31,7 +31,7 @@ class DatasetConfig:
repo_id: str
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
local_files_only: bool = False
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = "pyav"
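With this change, pinning a dataset version becomes a config concern; a sketch, assuming `DatasetConfig` is the class defined here:

```
from lerobot.configs.default import DatasetConfig  # module path assumed

# Pin the dataset to a tag instead of toggling local_files_only.
cfg = DatasetConfig(repo_id="lerobot/pusht", revision="v2.0")
```

From the CLI, the equivalent override would presumably be `--dataset.revision=v2.0`.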

View File

@ -85,7 +85,6 @@ python lerobot/scripts/control_robot.py record \
This might require a sudo permission to allow your terminal to monitor keyboard events.
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
If the dataset you want to extend is not on the hub, you also need to add `--control.local_files_only=true`.
- Train on this dataset with the ACT policy:
```bash
@ -216,7 +215,6 @@ def record(
dataset = LeRobotDataset(
cfg.repo_id,
root=cfg.root,
local_files_only=cfg.local_files_only,
)
if len(robot.cameras) > 0:
dataset.start_image_writer(
@ -318,9 +316,7 @@ def replay(
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
dataset = LeRobotDataset(
cfg.repo_id, root=cfg.root, episodes=[cfg.episode], local_files_only=cfg.local_files_only
)
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:

View File

@ -207,12 +207,6 @@ def main():
required=True,
help="Episode to visualize.",
)
parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser.add_argument(
"--root",
type=Path,
@ -275,10 +269,9 @@ def main():
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
dataset = LeRobotDataset(repo_id, root=root)
visualize_dataset(dataset, **vars(args))

View File

@ -384,12 +384,6 @@ def main():
default=None,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser.add_argument(
"--root",
type=Path,
@ -445,15 +439,10 @@ def main():
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
dataset = None
if repo_id:
dataset = (
LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)
visualize_dataset_html(dataset, **vars(args))

View File

@ -109,7 +109,7 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR
dataset = LeRobotDataset(
repo_id=cfg.repo_id,
episodes=cfg.episodes,
local_files_only=cfg.local_files_only,
revision=cfg.revision,
video_backend=cfg.video_backend,
)

View File

@ -47,6 +47,7 @@ dependencies = [
"numba>=0.59.0",
"omegaconf>=2.3.0",
"opencv-python>=4.9.0",
"packaging>=24.2",
"pyav>=12.0.5",
"pymunk>=6.6.0",
"rerun-sdk>=0.21.0",
@ -54,7 +55,7 @@ dependencies = [
"torch>=2.2.1",
"torchvision>=0.21.0",
"wandb>=0.16.3",
"zarr>=2.17.0"
"zarr>=2.17.0",
]
[project.optional-dependencies]

View File

@ -309,7 +309,6 @@ def lerobot_dataset_metadata_factory(
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
local_files_only: bool = False,
) -> LeRobotDatasetMetadata:
if not info:
info = info_factory()
@ -335,16 +334,16 @@ def lerobot_dataset_metadata_factory(
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
"lerobot.common.datasets.lerobot_dataset.get_safe_revision"
) as mock_get_safe_revision_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
return LeRobotDatasetMetadata(repo_id=repo_id, root=root)
return _create_lerobot_dataset_metadata
@ -411,15 +410,18 @@ def lerobot_dataset_factory(
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
local_files_only=kwargs.get("local_files_only", False),
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_revision"
) as mock_get_safe_revision_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_metadata_patch.return_value = mock_metadata
mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

View File

@ -167,9 +167,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay_cfg = ReplayControlConfig(
episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True
)
replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
replay(robot, replay_cfg)
policy_cfg = ACTConfig()
@ -266,7 +264,6 @@ def test_resume_record(tmp_path, request, robot_type, mock):
video=False,
display_cameras=False,
play_sounds=False,
local_files_only=True,
num_episodes=1,
)

View File

@ -77,10 +77,6 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
# Access the '_hub_version' cached_property in both instances to force its creation
_ = dataset_init.meta._hub_version
_ = dataset_create.meta._hub_version
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())