Fix datasets missing versions (#318)

Simon Alibert 2024-07-16 23:02:31 +02:00 committed by GitHub
parent 5f5efe7cb9
commit 8865e19c12
12 changed files with 156 additions and 120 deletions


@ -35,15 +35,16 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/codebase_version.md
CODEBASE_VERSION = "v1.5"
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
class LeRobotDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        repo_id: str,
        version: str | None = CODEBASE_VERSION,
        root: Path | None = DATA_DIR,
        split: str = "train",
        image_transforms: Callable | None = None,
@ -52,7 +53,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
    ):
        super().__init__()
        self.repo_id = repo_id
        self.version = version
        self.root = root
        self.split = split
        self.image_transforms = image_transforms
@ -60,16 +60,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # load data from hub or locally when root is provided
        # TODO(rcadene, aliberts): implement faster transfer
        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
        self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
        self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
        if split == "train":
            self.episode_data_index = load_episode_data_index(repo_id, version, root)
            self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
        else:
            self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
            self.hf_dataset = reset_episode_index(self.hf_dataset)
        self.stats = load_stats(repo_id, version, root)
        self.info = load_info(repo_id, version, root)
        self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
        self.info = load_info(repo_id, CODEBASE_VERSION, root)
        if self.video:
            self.videos_dir = load_videos(repo_id, version, root)
            self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
        self.video_backend = video_backend if video_backend is not None else "pyav"
    @property
@ -164,7 +164,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        return (
            f"{self.__class__.__name__}(\n"
            f"  Repository ID: '{self.repo_id}',\n"
            f"  Version: '{self.version}',\n"
            f"  Split: '{self.split}',\n"
            f"  Number of Samples: {self.num_samples},\n"
            f"  Number of Episodes: {self.num_episodes},\n"
@ -173,6 +172,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n"
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
f")"
)
@ -180,7 +180,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
    def from_preloaded(
        cls,
        repo_id: str = "from_preloaded",
        version: str | None = CODEBASE_VERSION,
        root: Path | None = None,
        split: str = "train",
        transform: callable = None,
@ -204,7 +203,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # create an empty object of type LeRobotDataset
        obj = cls.__new__(cls)
        obj.repo_id = repo_id
        obj.version = version
        obj.root = root
        obj.split = split
        obj.image_transforms = transform
@ -228,7 +226,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        repo_ids: list[str],
        version: str | None = CODEBASE_VERSION,
        root: Path | None = DATA_DIR,
        split: str = "train",
        image_transforms: Callable | None = None,
@ -242,7 +239,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
        self._datasets = [
            LeRobotDataset(
                repo_id,
                version=version,
                root=root,
                split=split,
                delta_timestamps=delta_timestamps,
@ -279,7 +275,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
            )
            self.disabled_data_keys.update(extra_keys)
        self.version = version
        self.root = root
        self.split = split
        self.image_transforms = image_transforms
@ -395,7 +390,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
        return (
            f"{self.__class__.__name__}(\n"
            f"  Repository IDs: '{self.repo_ids}',\n"
            f"  Version: '{self.version}',\n"
            f"  Split: '{self.split}',\n"
            f"  Number of Samples: {self.num_samples},\n"
            f"  Number of Episodes: {self.num_episodes},\n"


@ -0,0 +1,57 @@
## Using / Updating `CODEBASE_VERSION` (for maintainers)
Since the datasets we push to the hub are decoupled from the evolution of this repo, we use a
`CODEBASE_VERSION` variable (defined in `lerobot/common/datasets/lerobot_dataset.py`) to ensure that the
datasets remain compatible with our code.

For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version

Starting with v1.6, every dataset pushed to the hub or saved locally also has this version number in its
`info.json` metadata.
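
For example, here is a minimal sketch of how you could read the recorded version back, using the
`load_info` helper from `lerobot/common/datasets/utils.py` (the repo id is illustrative, and this assumes
network access to the hub):

```python
from lerobot.common.datasets.utils import load_info

# Datasets converted before v1.6 won't have the key, hence the "< v1.6" fallback
# (mirroring the fallback used in LeRobotDataset.__repr__).
info = load_info("lerobot/pusht", version="main", root=None)
print(info.get("codebase_version", "< v1.6"))
```
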
### Uploading a new dataset

If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor about
compatibility with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag
your dataset with the current `CODEBASE_VERSION`.

### Updating an existing dataset

If you want to update an existing dataset, you need to bump the `CODEBASE_VERSION` in `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially important if you introduce a breaking change,
intentionally or not (i.e. something that is not backward compatible, such as modifying the reward function
used or deleting some frames at the end of an episode). That way, people running a previous version of the
codebase won't be affected by your change, and backward compatibility is maintained.
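
Concretely, bumping the version is a one-line change in `lerobot/common/datasets/lerobot_dataset.py` (the
version numbers below are illustrative):

```python
# Before the breaking change:
CODEBASE_VERSION = "v1.5"
# After: bump it so that newly pushed datasets are tagged with the new version.
CODEBASE_VERSION = "v1.6"
```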

However, you will then need to update ALL the other datasets so that they also have the new
`CODEBASE_VERSION` as a branch in their Hugging Face dataset repository. Don't worry, there is an easy way
that doesn't require running `push_dataset_to_hub.py`. You can simply "branch out" from the `main` branch
of each HF dataset repo by running the script below, which is the equivalent of a `git checkout -b` (so no
copy or upload is needed):

```python
from huggingface_hub import HfApi

from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION

api = HfApi()

for repo_id in available_datasets:
    dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
    branches = [b.name for b in dataset_info.branches]
    # First check if the new version already exists as a branch.
    if CODEBASE_VERSION in branches:
        print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
        print("Exiting early")
        break
    else:
        # Create a branch named after the new version by branching out from "main",
        # which is expected to hold the preceding version.
        api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
        print(f"{repo_id} successfully updated")
```


@ -32,46 +32,41 @@ from pathlib import Path
from huggingface_hub import snapshot_download
AVAILABLE_RAW_REPO_IDS = [
"cadene/pusht_image_raw",
"cadene/xarm_lift_medium_image_raw",
"cadene/xarm_lift_medium_replay_image_raw",
"cadene/xarm_push_medium_image_raw",
"cadene/xarm_push_medium_replay_image_raw",
"cadene/aloha_sim_insertion_human_image_raw",
"cadene/aloha_sim_insertion_scripted_image_raw",
"cadene/aloha_sim_transfer_cube_human_image_raw",
"cadene/aloha_sim_transfer_cube_scripted_image_raw",
"cadene/pusht_raw",
"cadene/xarm_lift_medium_raw",
"cadene/xarm_lift_medium_replay_raw",
"cadene/xarm_push_medium_raw",
"cadene/xarm_push_medium_replay_raw",
"cadene/aloha_sim_insertion_human_raw",
"cadene/aloha_sim_insertion_scripted_raw",
"cadene/aloha_sim_transfer_cube_human_raw",
"cadene/aloha_sim_transfer_cube_scripted_raw",
"cadene/aloha_mobile_cabinet_raw",
"cadene/aloha_mobile_chair_raw",
"cadene/aloha_mobile_elevator_raw",
"cadene/aloha_mobile_shrimp_raw",
"cadene/aloha_mobile_wash_pan_raw",
"cadene/aloha_mobile_wipe_wine_raw",
"cadene/aloha_static_battery_raw",
"cadene/aloha_static_candy_raw",
"cadene/aloha_static_coffee_raw",
"cadene/aloha_static_coffee_new_raw",
"cadene/aloha_static_cups_open_raw",
"cadene/aloha_static_fork_pick_up_raw",
"cadene/aloha_static_pingpong_test_raw",
"cadene/aloha_static_pro_pencil_raw",
"cadene/aloha_static_screw_driver_raw",
"cadene/aloha_static_tape_raw",
"cadene/aloha_static_thread_velcro_raw",
"cadene/aloha_static_towel_raw",
"cadene/aloha_static_vinh_cup_raw",
"cadene/aloha_static_vinh_cup_left_raw",
"cadene/aloha_static_ziploc_slide_raw",
"cadene/umi_cup_in_the_wild_raw",
"lerobot-raw/aloha_mobile_cabinet_raw",
"lerobot-raw/aloha_mobile_chair_raw",
"lerobot-raw/aloha_mobile_elevator_raw",
"lerobot-raw/aloha_mobile_shrimp_raw",
"lerobot-raw/aloha_mobile_wash_pan_raw",
"lerobot-raw/aloha_mobile_wipe_wine_raw",
"lerobot-raw/aloha_sim_insertion_human_raw",
"lerobot-raw/aloha_sim_insertion_scripted_raw",
"lerobot-raw/aloha_sim_transfer_cube_human_raw",
"lerobot-raw/aloha_sim_transfer_cube_scripted_raw",
"lerobot-raw/aloha_static_battery_raw",
"lerobot-raw/aloha_static_candy_raw",
"lerobot-raw/aloha_static_coffee_new_raw",
"lerobot-raw/aloha_static_coffee_raw",
"lerobot-raw/aloha_static_cups_open_raw",
"lerobot-raw/aloha_static_fork_pick_up_raw",
"lerobot-raw/aloha_static_pingpong_test_raw",
"lerobot-raw/aloha_static_pro_pencil_raw",
"lerobot-raw/aloha_static_screw_driver_raw",
"lerobot-raw/aloha_static_tape_raw",
"lerobot-raw/aloha_static_thread_velcro_raw",
"lerobot-raw/aloha_static_towel_raw",
"lerobot-raw/aloha_static_vinh_cup_left_raw",
"lerobot-raw/aloha_static_vinh_cup_raw",
"lerobot-raw/aloha_static_ziploc_slide_raw",
"lerobot-raw/pusht_raw",
"lerobot-raw/umi_cup_in_the_wild_raw",
"lerobot-raw/unitreeh1_fold_clothes_raw",
"lerobot-raw/unitreeh1_rearrange_objects_raw",
"lerobot-raw/unitreeh1_two_robot_greeting_raw",
"lerobot-raw/unitreeh1_warehouse_raw",
"lerobot-raw/xarm_lift_medium_raw",
"lerobot-raw/xarm_lift_medium_replay_raw",
"lerobot-raw/xarm_push_medium_raw",
"lerobot-raw/xarm_push_medium_replay_raw",
]
@ -89,7 +84,6 @@ def download_raw(raw_dir: Path, repo_id: str):
        stacklevel=1,
    )
    raw_dir = Path(raw_dir)
    # Send warning if raw_dir isn't well formatted
    if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
        warnings.warn(
@ -99,7 +93,7 @@ def download_raw(raw_dir: Path, repo_id: str):
    raw_dir.mkdir(parents=True, exist_ok=True)
    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
    snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
    snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")


@ -28,6 +28,7 @@ import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
@ -210,6 +211,7 @@ def from_raw_to_lerobot_format(
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }


@ -23,6 +23,7 @@ import torch
from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
from lerobot.common.datasets.video_utils import VideoFrame
@ -95,6 +96,7 @@ def from_raw_to_lerobot_format(
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }


@ -24,6 +24,7 @@ import pandas as pd
import torch
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
@ -214,6 +215,7 @@ def from_raw_to_lerobot_format(
    hf_dataset = to_hf_dataset(data_df, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }


@ -25,6 +25,7 @@ import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
@ -258,6 +259,7 @@ def from_raw_to_lerobot_format(
    hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video if not keypoints_instead_of_image else 0,
    }


@ -25,6 +25,7 @@ import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
@ -199,6 +200,7 @@ def from_raw_to_lerobot_format(
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }


@ -25,6 +25,7 @@ import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
@ -177,6 +178,7 @@ def from_raw_to_lerobot_format(
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }


@ -15,13 +15,15 @@
# limitations under the License.
import json
import re
import warnings
from functools import cache
from pathlib import Path
from typing import Dict
import datasets
import torch
from datasets import load_dataset, load_from_disk
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
@ -80,7 +82,28 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
    return items_dict
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
@cache
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
    api = HfApi()
    dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
    branches = [b.name for b in dataset_info.branches]
    if version not in branches:
        warnings.warn(
            f"""You are trying to load a dataset from {repo_id} created with a previous version of the
            codebase. The following versions are available: {branches}.
            The requested version ('{version}') is not found. You should be fine since
            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
            stacklevel=1,
        )
        if "main" not in branches:
            raise ValueError(f"Version 'main' not found on {repo_id}")
        return "main"
    else:
        return version
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
    """hf_dataset contains all the observations, states, actions, rewards, etc."""
    if root is not None:
        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
@ -101,7 +124,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
                f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
            )
    else:
        hf_dataset = load_dataset(repo_id, revision=version, split=split)
        safe_version = get_hf_dataset_safe_version(repo_id, version)
        hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
@ -119,8 +144,9 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
    if root is not None:
        path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
    else:
        safe_version = get_hf_dataset_safe_version(repo_id, version)
        path = hf_hub_download(
            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
        )
    return load_file(path)
@ -137,7 +163,10 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
    if root is not None:
        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
    else:
        path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
        safe_version = get_hf_dataset_safe_version(repo_id, version)
        path = hf_hub_download(
            repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
        )
    stats = load_file(path)
    return unflatten_dict(stats)
@ -154,7 +183,8 @@ def load_info(repo_id, version, root) -> dict:
    if root is not None:
        path = Path(root) / repo_id / "meta_data" / "info.json"
    else:
        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
        safe_version = get_hf_dataset_safe_version(repo_id, version)
        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
    with open(path) as f:
        info = json.load(f)
@ -166,7 +196,8 @@ def load_videos(repo_id, version, root) -> Path:
        path = Path(root) / repo_id / "videos"
    else:
        # TODO(rcadene): we download the whole repo here. see if we can avoid this
        repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version)
        safe_version = get_hf_dataset_safe_version(repo_id, version)
        repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
        path = Path(repo_dir) / "videos"
    return path
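
A quick sketch of how this fallback behaves (hypothetical calls; assumes the `lerobot/pusht` repo hosts its
version branches on the hub):

```python
from lerobot.common.datasets.utils import get_hf_dataset_safe_version

# An existing version branch is returned unchanged:
assert get_hf_dataset_safe_version("lerobot/pusht", "v1.5") == "v1.5"
# An unknown version emits a warning and falls back to "main":
assert get_hf_dataset_safe_version("lerobot/pusht", "v9.9") == "main"
```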


@ -475,6 +475,7 @@ def record_dataset(
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }


@ -40,60 +40,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
--raw-format umi_zarr \
--repo-id lerobot/umi_cup_in_the_wild
```
**WARNING: Updating an existing dataset**
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change and backward compatibility is maintained.
For instance, Pusht has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
However, you will need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
```python
import os

from huggingface_hub import create_branch, hf_hub_download
from huggingface_hub.utils._errors import RepositoryNotFoundError

from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION

os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"  # makes it easier to see the print-out below

NEW_CODEBASE_VERSION = "v1.5"  # REPLACE THIS WITH YOUR DESIRED VERSION

for repo_id in available_datasets:
    # First check if the newer version already exists.
    try:
        hf_hub_download(
            repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
        )
        print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
        print("Exiting early")
        break
    except RepositoryNotFoundError:
        # Now create a branch.
        create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
        print(f"{repo_id} successfully updated")
```
On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
above, nor to be compatible with previous codebase versions.
"""
import argparse
@ -104,7 +50,7 @@ from pathlib import Path
from typing import Any
import torch
from huggingface_hub import HfApi, create_branch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
@ -270,7 +216,8 @@ def push_dataset_to_hub(
        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
        if video:
            push_videos_to_hub(repo_id, videos_dir, revision="main")
        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
        api = HfApi()
        api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
    if tests_data_dir:
        # get the first episode