Remove `local_files_only` and use `codebase_version` instead of branches (#734)
This commit is contained in:
parent 624eaf1175
commit fbf2f2222a
@@ -223,5 +223,5 @@ if __name__ == "__main__":
     main(raw_dir, repo_id=repo_id, mode=mode)

     # Uncomment if you want to load the local dataset and explore it
-    # dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
+    # dataset = LeRobotDataset(repo_id=repo_id)
     # breakpoint()
@@ -0,0 +1,40 @@
+import packaging.version
+
+V2_MESSAGE = """
+The dataset you requested ({repo_id}) is in {version} format.
+
+We introduced a new format since v2.0 which is not backward compatible with v1.x.
+Please, use our conversion script. Modify the following command with your own task description:
+```
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
+    --repo-id {repo_id} \\
+    --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
+```
+
+A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
+peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
+cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
+target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
+sweatshirt.", ...
+
+If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+"""
+
+V21_MESSAGE = """
+The dataset you requested ({repo_id}) is in {version} format.
+While the current version of LeRobot is backward-compatible with it, your dataset still uses global
+stats instead of per-episode stats. Update your dataset stats to the new format using this command:
+```
+python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
+```
+
+If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+"""
+
+
+class BackwardCompatibilityError(Exception):
+    def __init__(self, repo_id: str, version: packaging.version.Version):
+        message = V2_MESSAGE.format(repo_id=repo_id, version=version)
+        super().__init__(message)
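For context, a minimal sketch of how these templates and the error are meant to be consumed (the repo id and versions below are placeholders, not part of the diff):

```python
import packaging.version

from lerobot.common.datasets.backward_compatibility import V21_MESSAGE, BackwardCompatibilityError

# A v1.x dataset raises the hard error; its message is V2_MESSAGE formatted
# with the offending repo_id and version.
try:
    raise BackwardCompatibilityError("lerobot/pusht", packaging.version.parse("v1.6"))
except BackwardCompatibilityError as err:
    print(err)  # points the user at convert_dataset_v1_to_v2.py

# A v2.0 dataset only warns, reusing the V21_MESSAGE template.
print(V21_MESSAGE.format(repo_id="lerobot/pusht", version="v2.0"))
```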
@@ -83,15 +83,15 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
     )

     if isinstance(cfg.dataset.repo_id, str):
-        ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
+        ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, revision=cfg.dataset.revision)
         delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
         dataset = LeRobotDataset(
             cfg.dataset.repo_id,
             episodes=cfg.dataset.episodes,
             delta_timestamps=delta_timestamps,
             image_transforms=image_transforms,
+            revision=cfg.dataset.revision,
             video_backend=cfg.dataset.video_backend,
-            local_files_only=cfg.dataset.local_files_only,
         )
     else:
         raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
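In practice the new call shape looks like this — a hedged sketch, with `lerobot/pusht` standing in for any dataset repo:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

# Both metadata and data now resolve against a Git revision
# (branch, tag, or commit hash) instead of a local-only flag.
ds_meta = LeRobotDatasetMetadata("lerobot/pusht", revision="v2.1")
dataset = LeRobotDataset("lerobot/pusht", revision="v2.1")
```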
@@ -16,7 +16,6 @@
 import logging
 import os
 import shutil
-from functools import cached_property
 from pathlib import Path
 from typing import Callable

@@ -27,6 +26,8 @@ import torch
 import torch.utils
 from datasets import load_dataset
 from huggingface_hub import HfApi, snapshot_download
+from huggingface_hub.constants import REPOCARD_NAME
+from packaging import version

 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -41,14 +42,13 @@ from lerobot.common.datasets.utils import (
     check_frame_features,
     check_timestamps_sync,
     check_version_compatibility,
-    create_branch,
     create_empty_dataset_info,
     create_lerobot_dataset_card,
     get_delta_indices,
     get_episode_data_index,
     get_features_from_robot,
     get_hf_features_from_features,
-    get_hub_safe_version,
+    get_safe_revision,
     hf_transform_to_torch,
     load_episodes,
     load_episodes_stats,
@@ -79,30 +79,35 @@ class LeRobotDatasetMetadata:
         self,
         repo_id: str,
         root: str | Path | None = None,
-        local_files_only: bool = False,
+        revision: str | None = None,
+        force_cache_sync: bool = False,
     ):
         self.repo_id = repo_id
+        self.revision = revision if revision else CODEBASE_VERSION
         self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
-        self.local_files_only = local_files_only

         # Load metadata
-        (self.root / "meta").mkdir(exist_ok=True, parents=True)
-        self.pull_from_repo(allow_patterns="meta/")
+        try:
+            if force_cache_sync:
+                raise FileNotFoundError
+            self.load_metadata()
+        except (FileNotFoundError, NotADirectoryError):
+            (self.root / "meta").mkdir(exist_ok=True, parents=True)
+            self.revision = get_safe_revision(self.repo_id, self.revision)
+            self.pull_from_repo(allow_patterns="meta/")
+            self.load_metadata()
+
+        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)

     def load_metadata(self):
         self.info = load_info(self.root)
-        self.stats = load_stats(self.root)
         self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)
-        try:
-            self.episodes_stats = load_episodes_stats(self.root)
-            self.stats = aggregate_stats(list(self.episodes_stats.values()))
-        except FileNotFoundError:
-            logging.warning(
-                f"""'episodes_stats.jsonl' not found. Using global dataset stats for each episode instead.
-                Convert your dataset stats to the new format using this command:
-                python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={self.repo_id} """
-            )
+        if version.parse(self._version) < version.parse("v2.1"):
+            self.stats = load_stats(self.root)
+            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
+        else:
+            self.episodes_stats = load_episodes_stats(self.root)
+            self.stats = aggregate_stats(list(self.episodes_stats.values()))

     def pull_from_repo(
         self,
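The constructor now tries the local cache first and only falls back to the hub on a miss (or when `force_cache_sync=True` deliberately forces one). A minimal sketch of the same pattern in isolation, with hypothetical `load`/`sync` callables standing in for `load_metadata` and `pull_from_repo`:

```python
from pathlib import Path
from typing import Callable

def load_with_cache_fallback(root: Path, load: Callable, sync: Callable, force_cache_sync: bool = False):
    """Try the local cache first; fall back to a hub sync on a miss."""
    try:
        if force_cache_sync:
            raise FileNotFoundError  # skip the cache on purpose
        return load(root)
    except (FileNotFoundError, NotADirectoryError):
        root.mkdir(exist_ok=True, parents=True)
        sync(root)  # e.g. snapshot_download scoped to "meta/"
        return load(root)
```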
@@ -112,17 +117,12 @@ class LeRobotDatasetMetadata:
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self._hub_version,
+            revision=self.revision,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
-            local_files_only=self.local_files_only,
         )

-    @cached_property
-    def _hub_version(self) -> str | None:
-        return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
-
     @property
     def _version(self) -> str:
         """Codebase version used to create this dataset."""
@@ -342,7 +342,7 @@ class LeRobotDatasetMetadata:
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
         write_json(obj.info, obj.root / INFO_PATH)
-        obj.local_files_only = True
+        obj.revision = None
         return obj
@@ -355,8 +355,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
         tolerance_s: float = 1e-4,
+        revision: str | None = None,
+        force_cache_sync: bool = False,
         download_videos: bool = True,
-        local_files_only: bool = False,
         video_backend: str | None = None,
     ):
         """
@@ -366,7 +367,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             - On your local disk in the 'root' folder. This is typically the case when you recorded your
               dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
               with 'root' will load your dataset directly from disk. This can happen while you're offline (no
-              internet connection), in that case, use local_files_only=True.
+              internet connection).

             - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
               your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
@@ -448,11 +449,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
                 decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
                 multiples of 1/fps. Defaults to 1e-4.
+            revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
+                commit hash. Defaults to the current codebase version tag.
+            force_cache_sync (bool, optional): Flag to sync and refresh local files first. If False and files
+                are already present in the local cache, loading is faster. However, the files loaded might not
+                be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
+                False.
             download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                 video files are already present on local disk, they won't be downloaded again. Defaults to
                 True.
-            local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
-                will be made. Defaults to False.
             video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
                 a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
         """
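As a hedged illustration of the two cache behaviors documented above (the dataset id is a placeholder):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Fast path: trust whatever is already in the local cache.
dataset = LeRobotDataset("lerobot/pusht")

# Strict path: re-resolve the revision and re-sync files from the hub,
# e.g. right after switching 'revision' to an older tag.
dataset = LeRobotDataset("lerobot/pusht", revision="v2.0", force_cache_sync=True)
```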
@@ -463,9 +468,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
+        self.revision = revision if revision else CODEBASE_VERSION
         self.video_backend = video_backend if video_backend else "pyav"
         self.delta_indices = None
-        self.local_files_only = local_files_only

         # Unused attributes
         self.image_writer = None
@@ -474,17 +479,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.root.mkdir(exist_ok=True, parents=True)

         # Load metadata
-        self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
-        if self.episodes is not None and self.meta._version == CODEBASE_VERSION:
+        self.meta = LeRobotDatasetMetadata(
+            self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
+        )
+        if self.episodes is not None and version.parse(self.meta._version) >= version.parse("v2.1"):
             episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
             self.stats = aggregate_stats(episodes_stats)

-        # Check version
-        check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
-
         # Load actual data
-        self.download_episodes(download_videos)
-        self.hf_dataset = self.load_hf_dataset()
+        try:
+            if force_cache_sync:
+                raise FileNotFoundError
+            assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
+            self.hf_dataset = self.load_hf_dataset()
+        except (AssertionError, FileNotFoundError, NotADirectoryError):
+            self.revision = get_safe_revision(self.repo_id, self.revision)
+            self.download_episodes(download_videos)
+            self.hf_dataset = self.load_hf_dataset()

         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

         # Check timestamps
@@ -501,7 +513,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def push_to_hub(
         self,
         branch: str | None = None,
-        create_card: bool = True,
         tags: list | None = None,
         license: str | None = "apache-2.0",
         push_videos: bool = True,
@@ -528,7 +539,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
             exist_ok=True,
         )
         if branch:
-            create_branch(repo_id=self.repo_id, branch=branch, repo_type="dataset")
+            hub_api.create_branch(
+                repo_id=self.repo_id,
+                branch=branch,
+                revision=self.revision,
+                repo_type="dataset",
+                exist_ok=True,
+            )

         hub_api.upload_folder(
             repo_id=self.repo_id,
@@ -538,15 +555,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
         )
-        if create_card:
+        if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
             card = create_lerobot_dataset_card(
                 tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
             )
             card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)

-        if not branch:
-            create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
-
     def pull_from_repo(
         self,
         allow_patterns: list[str] | str | None = None,
@@ -555,11 +569,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self.meta._hub_version,
+            revision=self.revision,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
-            local_files_only=self.local_files_only,
         )

     def download_episodes(self, download_videos: bool = True) -> None:
@@ -573,17 +586,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
         files = None
         ignore_patterns = None if download_videos else "videos/"
         if self.episodes is not None:
-            files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
-            if len(self.meta.video_keys) > 0 and download_videos:
-                video_files = [
-                    str(self.meta.get_video_file_path(ep_idx, vid_key))
-                    for vid_key in self.meta.video_keys
-                    for ep_idx in self.episodes
-                ]
-                files += video_files
+            files = self.get_episodes_file_paths()

         self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

+    def get_episodes_file_paths(self) -> list[Path]:
+        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
+        fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
+        if len(self.meta.video_keys) > 0:
+            video_files = [
+                str(self.meta.get_video_file_path(ep_idx, vid_key))
+                for vid_key in self.meta.video_keys
+                for ep_idx in episodes
+            ]
+            fpaths += video_files
+
+        return fpaths
+
     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
         if self.episodes is None:
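Factoring the file list into `get_episodes_file_paths()` lets one list serve both the cache-completeness check in `__init__` and the `allow_patterns` passed to `snapshot_download`. A hedged sketch of the resulting partial download (episode indices are placeholders):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Only the data and video files of episodes 0 and 5 are pulled from the hub.
dataset = LeRobotDataset("lerobot/pusht", episodes=[0, 5])
print(dataset.get_episodes_file_paths())  # the exact patterns used for the download
```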
@@ -991,7 +1010,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
         obj.repo_id = obj.meta.repo_id
         obj.root = obj.meta.root
-        obj.local_files_only = obj.meta.local_files_only
+        obj.revision = None
         obj.tolerance_s = tolerance_s
         obj.image_writer = None
@@ -1033,7 +1052,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         delta_timestamps: dict[list[float]] | None = None,
         tolerances_s: dict | None = None,
         download_videos: bool = True,
-        local_files_only: bool = False,
         video_backend: str | None = None,
     ):
         super().__init__()
@@ -1051,7 +1069,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 delta_timestamps=delta_timestamps,
                 tolerance_s=self.tolerances_s[repo_id],
                 download_videos=download_videos,
-                local_files_only=local_files_only,
                 video_backend=video_backend,
             )
             for repo_id in repo_ids
@@ -13,10 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 import importlib.resources
 import json
 import logging
-import textwrap
 from collections.abc import Iterator
 from itertools import accumulate
 from pathlib import Path
@@ -31,9 +31,11 @@ import pyarrow.compute as pc
 import torch
 from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, DatasetCardData, HfApi
+from packaging import version
 from PIL import Image as PILImage
 from torchvision import transforms

+from lerobot.common.datasets.backward_compatibility import V21_MESSAGE, BackwardCompatibilityError
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import is_valid_numpy_dtype_string
 from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
@@ -200,7 +202,7 @@ def write_task(task_index: int, task: dict, local_dir: Path):
     append_jsonlines(task_dict, local_dir / TASKS_PATH)


-def load_tasks(local_dir: Path) -> dict:
+def load_tasks(local_dir: Path) -> tuple[dict, dict]:
     tasks = load_jsonlines(local_dir / TASKS_PATH)
     tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
     task_to_task_index = {task: task_index for task_index, task in tasks.items()}
@@ -231,7 +233,9 @@ def load_episodes_stats(local_dir: Path) -> dict:
     }


-def backward_compatible_episodes_stats(stats, episodes: list[int]) -> dict[str, dict[str, np.ndarray]]:
+def backward_compatible_episodes_stats(
+    stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
+) -> dict[str, dict[str, np.ndarray]]:
     return {ep_idx: stats for ep_idx in episodes}
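The shim just maps every episode index onto the one global stats dict, so pre-v2.1 datasets expose the same per-episode interface as converted ones. For instance, assuming two episodes and toy numbers:

```python
import numpy as np

stats = {"action": {"mean": np.zeros(6), "std": np.ones(6)}}  # global stats from stats.json
episodes_stats = {ep_idx: stats for ep_idx in [0, 1]}
assert episodes_stats[0] is episodes_stats[1]  # every episode shares the same dict
```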
@@ -265,73 +269,38 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict


-def _get_major_minor(version: str) -> tuple[int]:
-    split = version.strip("v").split(".")
-    return int(split[0]), int(split[1])
-
-
-class BackwardCompatibilityError(Exception):
-    def __init__(self, repo_id, version):
-        message = textwrap.dedent(f"""
-            BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
-
-            We introduced a new format since v2.0 which is not backward compatible with v1.x.
-            Please, use our conversion script. Modify the following command with your own task description:
-            ```
-            python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
-                --repo-id {repo_id} \\
-                --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
-            ```
-
-            A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
-            "Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
-            "Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
-            "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
-
-            If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
-            or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
-            """)
-        super().__init__(message)
-
-
 def check_version_compatibility(
     repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
 ) -> None:
-    current_major, _ = _get_major_minor(current_version)
-    major_to_check, _ = _get_major_minor(version_to_check)
-    if major_to_check < current_major and enforce_breaking_major:
-        raise BackwardCompatibilityError(repo_id, version_to_check)
-    elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
-        logging.warning(
-            f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
-            codebase. The current codebase version is {current_version}. You should be fine since
-            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
-            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
-        )
+    v_check = version.parse(version_to_check)
+    v_current = version.parse(current_version)
+    if v_check.major < v_current.major and enforce_breaking_major:
+        raise BackwardCompatibilityError(repo_id, v_check)
+    elif v_check.minor < v_current.minor:
+        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=version_to_check))


-def get_hub_safe_version(repo_id: str, version: str) -> str:
+def get_repo_versions(repo_id: str) -> list[version.Version]:
+    """Returns available valid versions (branches and tags) on the given repo."""
     api = HfApi()
-    dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
-    branches = [b.name for b in dataset_info.branches]
-    if version not in branches:
-        num_version = float(version.strip("v"))
-        hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
-        if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
-            raise BackwardCompatibilityError(repo_id, version)
-
-        logging.warning(
-            f"""You are trying to load a dataset from {repo_id} created with a previous version of the
-            codebase. The following versions are available: {branches}.
-            The requested version ('{version}') is not found. You should be fine since
-            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
-            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
-        )
-        if "main" not in branches:
-            raise ValueError(f"Version 'main' not found on {repo_id}")
-        return "main"
-    else:
-        return version
+    repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
+    repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
+    repo_versions = []
+    for ref in repo_refs:
+        with contextlib.suppress(version.InvalidVersion):
+            repo_versions.append(version.parse(ref))
+
+    return repo_versions
+
+
+def get_safe_revision(repo_id: str, revision: str) -> str:
+    """Returns the revision if available on the repo, otherwise the latest available version."""
+    api = HfApi()
+    if api.revision_exists(repo_id, revision, repo_type="dataset"):
+        return revision
+
+    hub_versions = get_repo_versions(repo_id)
+    return f"v{max(hub_versions)}"


 def get_hf_features_from_features(features: dict) -> datasets.Features:
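`get_safe_revision` degrades gracefully: an exact hit on the requested revision wins, otherwise the highest ref that parses as a version is returned. A hedged sketch of that selection with `packaging` (the ref names are made up):

```python
import contextlib

from packaging import version

refs = ["main", "v1.6", "v2.0", "v2.1", "some-feature-branch"]

versions = []
for ref in refs:
    with contextlib.suppress(version.InvalidVersion):
        versions.append(version.parse(ref))

print(f"v{max(versions)}")  # -> v2.1; refs that are not versions are skipped
```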
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
     create_branch,
     create_lerobot_dataset_card,
     flatten_dict,
-    get_hub_safe_version,
+    get_safe_revision,
     load_json,
     unflatten_dict,
     write_json,
@@ -443,7 +443,7 @@ def convert_dataset(
     test_branch: str | None = None,
     **card_kwargs,
 ):
-    v1 = get_hub_safe_version(repo_id, V16)
+    v1 = get_safe_revision(repo_id, V16)
     v1x_dir = local_dir / V16 / repo_id
     v20_dir = local_dir / V20 / repo_id
     v1x_dir.mkdir(parents=True, exist_ok=True)
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
+"""
+
+import traceback
+from pathlib import Path
+
+from lerobot import available_datasets
+from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset
+
+LOCAL_DIR = Path("data/")
+
+
+def batch_convert():
+    status = {}
+    logfile = LOCAL_DIR / "conversion_log_v21.txt"
+    for num, repo_id in enumerate(available_datasets):
+        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+        print("---------------------------------------------------------")
+        try:
+            convert_dataset(repo_id)
+            status = f"{repo_id}: success."
+            with open(logfile, "a") as file:
+                file.write(status + "\n")
+        except Exception:
+            status = f"{repo_id}: failed\n    {traceback.format_exc()}"
+            with open(logfile, "a") as file:
+                file.write(status + "\n")
+            continue
+
+
+if __name__ == "__main__":
+    batch_convert()
@@ -1,10 +1,12 @@
 """
 This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
-2.1. It performs the following:
+2.1. It will:

 - Generate per-episode stats and write them in `episodes_stats.jsonl`
 - Check consistency between these new stats and the old ones
+- Remove the deprecated `stats.json` (by default)
+- Update codebase_version in `info.json`
+- Push this new version to the hub on the 'main' branch and tag it with "v2.1"

 Usage:
@@ -14,9 +16,9 @@ python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
 ```

 """
-# TODO(rcadene, aliberts): ensure this script works for any other changes for the final v2.1

 import argparse
+import logging

 from huggingface_hub import HfApi
@@ -24,14 +26,27 @@ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
 from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats

 V20 = "v2.0"
 V21 = "v2.1"

-def main(
+
+class SuppressWarnings:
+    def __enter__(self):
+        self.previous_level = logging.getLogger().getEffectiveLevel()
+        logging.getLogger().setLevel(logging.ERROR)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        logging.getLogger().setLevel(self.previous_level)
+
+
+def convert_dataset(
     repo_id: str,
-    test_branch: str | None = None,
-    delete_old_stats: bool = False,
+    branch: str | None = None,
     num_workers: int = 4,
 ):
-    dataset = LeRobotDataset(repo_id)
+    with SuppressWarnings():
+        dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
+
+    if (dataset.root / EPISODES_STATS_PATH).is_file():
+        raise FileExistsError("episodes_stats.jsonl already exists.")
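`SuppressWarnings` is a small stdlib-only context manager: it bumps the root logger to ERROR on entry and restores the previous level on exit, so loading the v2.0 dataset doesn't spam the version warnings shown earlier. Usage, as a sketch:

```python
import logging

from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import SuppressWarnings

with SuppressWarnings():
    logging.warning("hidden: the root logger sits at ERROR here")
logging.warning("visible again: the previous level is restored")
```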
@@ -42,18 +57,21 @@ def convert_dataset(
     dataset.meta.info["codebase_version"] = CODEBASE_VERSION
     write_info(dataset.meta.info, dataset.root)

-    dataset.push_to_hub(branch=test_branch, create_card=False, allow_patterns="meta/")
+    dataset.push_to_hub(branch=branch, allow_patterns="meta/")

-    if delete_old_stats:
-        if (dataset.root / STATS_PATH).is_file():
-            (dataset.root / STATS_PATH).unlink()
-        hub_api = HfApi()
-        if hub_api.file_exists(
-            STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
-        ):
-            hub_api.delete_file(
-                STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
-            )
+    # delete old stats.json file
+    if (dataset.root / STATS_PATH).is_file():
+        (dataset.root / STATS_PATH).unlink()
+
+    hub_api = HfApi()
+    if hub_api.file_exists(
+        repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
+    ):
+        hub_api.delete_file(
+            path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
+        )
+
+    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")


 if __name__ == "__main__":
@@ -65,16 +83,10 @@ if __name__ == "__main__":
         help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
     )
     parser.add_argument(
-        "--test-branch",
+        "--branch",
         type=str,
         default=None,
-        help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
-    )
-    parser.add_argument(
-        "--delete-old-stats",
-        type=bool,
-        default=False,
-        help="Delete the deprecated `stats.json`",
+        help="Repo branch to push your dataset (defaults to the main branch)",
     )
     parser.add_argument(
         "--num-workers",
@@ -84,4 +96,4 @@ if __name__ == "__main__":
     )

     args = parser.parse_args()
-    main(**vars(args))
+    convert_dataset(**vars(args))
@@ -81,9 +81,6 @@ class RecordControlConfig(ControlConfig):
     play_sounds: bool = True
     # Resume recording on an existing dataset.
     resume: bool = False
-    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
-    # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
-    local_files_only: bool = False

     def __post_init__(self):
         # HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -128,9 +125,6 @@ class ReplayControlConfig(ControlConfig):
     fps: int | None = None
     # Use vocal synthesis to read events.
     play_sounds: bool = True
-    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
-    # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
-    local_files_only: bool = False


 @dataclass
@@ -31,7 +31,7 @@ class DatasetConfig:
     repo_id: str
     episodes: list[int] | None = None
     image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
-    local_files_only: bool = False
+    revision: str | None = None
     use_imagenet_stats: bool = True
     video_backend: str = "pyav"
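With `revision` replacing `local_files_only` in `DatasetConfig`, pinning a dataset version becomes a plain config override on the training CLI. A hedged example, assuming the usual draccus-style flags:

```bash
python lerobot/scripts/train.py \
    --dataset.repo_id=lerobot/pusht \
    --dataset.revision=v2.1
```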
@@ -85,7 +85,6 @@ python lerobot/scripts/control_robot.py record \
 This might require a sudo permission to allow your terminal to monitor keyboard events.

 **NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
-If the dataset you want to extend is not on the hub, you also need to add `--control.local_files_only=true`.

 - Train on this dataset with the ACT policy:
 ```bash
@@ -216,7 +215,6 @@ def record(
     dataset = LeRobotDataset(
         cfg.repo_id,
         root=cfg.root,
-        local_files_only=cfg.local_files_only,
     )
     if len(robot.cameras) > 0:
         dataset.start_image_writer(
@@ -318,9 +316,7 @@ def replay(
     # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
     # TODO(rcadene): Add option to record logs

-    dataset = LeRobotDataset(
-        cfg.repo_id, root=cfg.root, episodes=[cfg.episode], local_files_only=cfg.local_files_only
-    )
+    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
     actions = dataset.hf_dataset.select_columns("action")

     if not robot.is_connected:
@@ -207,12 +207,6 @@ def main():
         required=True,
         help="Episode to visualize.",
     )
-    parser.add_argument(
-        "--local-files-only",
-        type=int,
-        default=0,
-        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
-    )
     parser.add_argument(
         "--root",
         type=Path,
@@ -275,10 +269,9 @@ def main():
     kwargs = vars(args)
     repo_id = kwargs.pop("repo_id")
     root = kwargs.pop("root")
-    local_files_only = kwargs.pop("local_files_only")

     logging.info("Loading dataset")
-    dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
+    dataset = LeRobotDataset(repo_id, root=root)

     visualize_dataset(dataset, **vars(args))
@@ -384,12 +384,6 @@ def main():
         default=None,
         help="Name of Hugging Face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
     )
-    parser.add_argument(
-        "--local-files-only",
-        type=int,
-        default=0,
-        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
-    )
     parser.add_argument(
         "--root",
         type=Path,
@@ -445,15 +439,10 @@ def main():
     repo_id = kwargs.pop("repo_id")
     load_from_hf_hub = kwargs.pop("load_from_hf_hub")
     root = kwargs.pop("root")
-    local_files_only = kwargs.pop("local_files_only")

     dataset = None
     if repo_id:
-        dataset = (
-            LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
-            if not load_from_hf_hub
-            else get_dataset_info(repo_id)
-        )
+        dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)

     visualize_dataset_html(dataset, **vars(args))
@@ -109,7 +109,7 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR):
     dataset = LeRobotDataset(
         repo_id=cfg.repo_id,
         episodes=cfg.episodes,
-        local_files_only=cfg.local_files_only,
+        revision=cfg.revision,
         video_backend=cfg.video_backend,
     )
@@ -47,6 +47,7 @@ dependencies = [
     "numba>=0.59.0",
     "omegaconf>=2.3.0",
     "opencv-python>=4.9.0",
+    "packaging>=24.2",
     "pyav>=12.0.5",
     "pymunk>=6.6.0",
     "rerun-sdk>=0.21.0",
@@ -54,7 +55,7 @@ dependencies = [
     "torch>=2.2.1",
     "torchvision>=0.21.0",
     "wandb>=0.16.3",
-    "zarr>=2.17.0"
+    "zarr>=2.17.0",
 ]

 [project.optional-dependencies]
@@ -309,7 +309,6 @@ def lerobot_dataset_metadata_factory(
     episodes_stats: list[dict] | None = None,
     tasks: list[dict] | None = None,
     episodes: list[dict] | None = None,
-    local_files_only: bool = False,
 ) -> LeRobotDatasetMetadata:
     if not info:
         info = info_factory()
@@ -335,16 +334,16 @@ def lerobot_dataset_metadata_factory(
     )
     with (
         patch(
-            "lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
-        ) as mock_get_hub_safe_version_patch,
+            "lerobot.common.datasets.lerobot_dataset.get_safe_revision"
+        ) as mock_get_safe_revision_patch,
         patch(
             "lerobot.common.datasets.lerobot_dataset.snapshot_download"
         ) as mock_snapshot_download_patch,
     ):
-        mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
+        mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
         mock_snapshot_download_patch.side_effect = mock_snapshot_download

-        return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
+        return LeRobotDatasetMetadata(repo_id=repo_id, root=root)

     return _create_lerobot_dataset_metadata
@@ -411,15 +410,18 @@ def lerobot_dataset_factory(
         episodes_stats=episodes_stats,
         tasks=tasks,
         episodes=episode_dicts,
-        local_files_only=kwargs.get("local_files_only", False),
     )
     with (
         patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
+        patch(
+            "lerobot.common.datasets.lerobot_dataset.get_safe_revision"
+        ) as mock_get_safe_revision_patch,
         patch(
             "lerobot.common.datasets.lerobot_dataset.snapshot_download"
         ) as mock_snapshot_download_patch,
     ):
         mock_metadata_patch.return_value = mock_metadata
+        mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
         mock_snapshot_download_patch.side_effect = mock_snapshot_download

         return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
@@ -167,9 +167,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
     assert dataset.meta.total_episodes == 2
     assert len(dataset) == 2

-    replay_cfg = ReplayControlConfig(
-        episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True
-    )
+    replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
     replay(robot, replay_cfg)

     policy_cfg = ACTConfig()
@@ -266,7 +264,6 @@ def test_resume_record(tmp_path, request, robot_type, mock):
         video=False,
         display_cameras=False,
         play_sounds=False,
-        local_files_only=True,
         num_episodes=1,
     )
@@ -77,10 +77,6 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     root_init = tmp_path / "init"
     dataset_init = lerobot_dataset_factory(root=root_init)

-    # Access the '_hub_version' cached_property in both instances to force its creation
-    _ = dataset_init.meta._hub_version
-    _ = dataset_create.meta._hub_version
-
     init_attr = set(vars(dataset_init).keys())
     create_attr = set(vars(dataset_create).keys())