diff --git a/examples/advanced/1_train_act_pusht/act_pusht.yaml b/examples/advanced/1_train_act_pusht/act_pusht.yaml index 38e542fb..4963e11c 100644 --- a/examples/advanced/1_train_act_pusht/act_pusht.yaml +++ b/examples/advanced/1_train_act_pusht/act_pusht.yaml @@ -80,7 +80,7 @@ policy: n_vae_encoder_layers: 4 # Inference. - temporal_ensemble_momentum: null + temporal_ensemble_coeff: null # Training and loss computation. dropout: 0.1 diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 1bf336e0..29800c5c 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -35,15 +35,16 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos -DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None +# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/codebase_version.md CODEBASE_VERSION = "v1.5" +DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None + class LeRobotDataset(torch.utils.data.Dataset): def __init__( self, repo_id: str, - version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", image_transforms: Callable | None = None, @@ -52,7 +53,6 @@ class LeRobotDataset(torch.utils.data.Dataset): ): super().__init__() self.repo_id = repo_id - self.version = version self.root = root self.split = split self.image_transforms = image_transforms @@ -60,16 +60,16 @@ class LeRobotDataset(torch.utils.data.Dataset): # load data from hub or locally when root is provided # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads - self.hf_dataset = load_hf_dataset(repo_id, version, root, split) + self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split) if split == "train": - self.episode_data_index = load_episode_data_index(repo_id, version, root) + self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root) else: self.episode_data_index = calculate_episode_data_index(self.hf_dataset) self.hf_dataset = reset_episode_index(self.hf_dataset) - self.stats = load_stats(repo_id, version, root) - self.info = load_info(repo_id, version, root) + self.stats = load_stats(repo_id, CODEBASE_VERSION, root) + self.info = load_info(repo_id, CODEBASE_VERSION, root) if self.video: - self.videos_dir = load_videos(repo_id, version, root) + self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root) self.video_backend = video_backend if video_backend is not None else "pyav" @property @@ -164,7 +164,6 @@ class LeRobotDataset(torch.utils.data.Dataset): return ( f"{self.__class__.__name__}(\n" f" Repository ID: '{self.repo_id}',\n" - f" Version: '{self.version}',\n" f" Split: '{self.split}',\n" f" Number of Samples: {self.num_samples},\n" f" Number of Episodes: {self.num_episodes},\n" @@ -173,6 +172,7 @@ class LeRobotDataset(torch.utils.data.Dataset): f" Camera Keys: {self.camera_keys},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" f" Transformations: {self.image_transforms},\n" + f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n" f")" ) @@ -180,7 +180,6 @@ class LeRobotDataset(torch.utils.data.Dataset): def from_preloaded( cls, repo_id: str = "from_preloaded", - version: str | None = CODEBASE_VERSION, root: Path | None = None, split: str = "train", transform: callable = None, @@ -204,7 +203,6 @@ class LeRobotDataset(torch.utils.data.Dataset): # create an empty object of type LeRobotDataset obj = cls.__new__(cls) obj.repo_id = repo_id - obj.version = version obj.root = root obj.split = split obj.image_transforms = transform @@ -228,7 +226,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): def __init__( self, repo_ids: list[str], - version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", image_transforms: Callable | None = None, @@ -242,7 +239,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): self._datasets = [ LeRobotDataset( repo_id, - version=version, root=root, split=split, delta_timestamps=delta_timestamps, @@ -279,7 +275,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): ) self.disabled_data_keys.update(extra_keys) - self.version = version self.root = root self.split = split self.image_transforms = image_transforms @@ -395,7 +390,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): return ( f"{self.__class__.__name__}(\n" f" Repository IDs: '{self.repo_ids}',\n" - f" Version: '{self.version}',\n" f" Split: '{self.split}',\n" f" Number of Samples: {self.num_samples},\n" f" Number of Episodes: {self.num_episodes},\n" diff --git a/lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md b/lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md new file mode 100644 index 00000000..77948b02 --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md @@ -0,0 +1,57 @@ +## Using / Updating `CODEBASE_VERSION` (for maintainers) + +Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of +the datasets with our code, we use a `CODEBASE_VERSION` (defined in +lerobot/common/datasets/lerobot_dataset.py) variable. + +For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions: +- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0) +- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1) +- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2) +- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3) +- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4) +- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version +- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version + +Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their +`info.json` metadata. + +### Uploading a new dataset +If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be +compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your +dataset with the current `CODEBASE_VERSION`. + +### Updating an existing dataset +If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py` +before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change +intentionally or not (i.e. something not backward compatible such as modifying the reward functions used, +deleting some frames at the end of an episode, etc.). That way, people running a previous version of the +codebase won't be affected by your change and backward compatibility is maintained. + +However, you will need to update the version of ALL the other datasets so that they have the new +`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way +that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF +dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed): + +```python +from huggingface_hub import HfApi + +from lerobot import available_datasets +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION + +api = HfApi() + +for repo_id in available_datasets: + dataset_info = api.list_repo_refs(repo_id, repo_type="dataset") + branches = [b.name for b in dataset_info.branches] + if CODEBASE_VERSION in branches: + # First check if the newer version already exists. + print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.") + print("Exiting early") + break + else: + # Now create a branch named after the new version by branching out from "main" + # which is expected to be the preceding version + api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main") + print(f"{repo_id} successfully updated") +``` diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py index 91ba9ef1..b630bbca 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py @@ -32,46 +32,41 @@ from pathlib import Path from huggingface_hub import snapshot_download AVAILABLE_RAW_REPO_IDS = [ - "cadene/pusht_image_raw", - "cadene/xarm_lift_medium_image_raw", - "cadene/xarm_lift_medium_replay_image_raw", - "cadene/xarm_push_medium_image_raw", - "cadene/xarm_push_medium_replay_image_raw", - "cadene/aloha_sim_insertion_human_image_raw", - "cadene/aloha_sim_insertion_scripted_image_raw", - "cadene/aloha_sim_transfer_cube_human_image_raw", - "cadene/aloha_sim_transfer_cube_scripted_image_raw", - "cadene/pusht_raw", - "cadene/xarm_lift_medium_raw", - "cadene/xarm_lift_medium_replay_raw", - "cadene/xarm_push_medium_raw", - "cadene/xarm_push_medium_replay_raw", - "cadene/aloha_sim_insertion_human_raw", - "cadene/aloha_sim_insertion_scripted_raw", - "cadene/aloha_sim_transfer_cube_human_raw", - "cadene/aloha_sim_transfer_cube_scripted_raw", - "cadene/aloha_mobile_cabinet_raw", - "cadene/aloha_mobile_chair_raw", - "cadene/aloha_mobile_elevator_raw", - "cadene/aloha_mobile_shrimp_raw", - "cadene/aloha_mobile_wash_pan_raw", - "cadene/aloha_mobile_wipe_wine_raw", - "cadene/aloha_static_battery_raw", - "cadene/aloha_static_candy_raw", - "cadene/aloha_static_coffee_raw", - "cadene/aloha_static_coffee_new_raw", - "cadene/aloha_static_cups_open_raw", - "cadene/aloha_static_fork_pick_up_raw", - "cadene/aloha_static_pingpong_test_raw", - "cadene/aloha_static_pro_pencil_raw", - "cadene/aloha_static_screw_driver_raw", - "cadene/aloha_static_tape_raw", - "cadene/aloha_static_thread_velcro_raw", - "cadene/aloha_static_towel_raw", - "cadene/aloha_static_vinh_cup_raw", - "cadene/aloha_static_vinh_cup_left_raw", - "cadene/aloha_static_ziploc_slide_raw", - "cadene/umi_cup_in_the_wild_raw", + "lerobot-raw/aloha_mobile_cabinet_raw", + "lerobot-raw/aloha_mobile_chair_raw", + "lerobot-raw/aloha_mobile_elevator_raw", + "lerobot-raw/aloha_mobile_shrimp_raw", + "lerobot-raw/aloha_mobile_wash_pan_raw", + "lerobot-raw/aloha_mobile_wipe_wine_raw", + "lerobot-raw/aloha_sim_insertion_human_raw", + "lerobot-raw/aloha_sim_insertion_scripted_raw", + "lerobot-raw/aloha_sim_transfer_cube_human_raw", + "lerobot-raw/aloha_sim_transfer_cube_scripted_raw", + "lerobot-raw/aloha_static_battery_raw", + "lerobot-raw/aloha_static_candy_raw", + "lerobot-raw/aloha_static_coffee_new_raw", + "lerobot-raw/aloha_static_coffee_raw", + "lerobot-raw/aloha_static_cups_open_raw", + "lerobot-raw/aloha_static_fork_pick_up_raw", + "lerobot-raw/aloha_static_pingpong_test_raw", + "lerobot-raw/aloha_static_pro_pencil_raw", + "lerobot-raw/aloha_static_screw_driver_raw", + "lerobot-raw/aloha_static_tape_raw", + "lerobot-raw/aloha_static_thread_velcro_raw", + "lerobot-raw/aloha_static_towel_raw", + "lerobot-raw/aloha_static_vinh_cup_left_raw", + "lerobot-raw/aloha_static_vinh_cup_raw", + "lerobot-raw/aloha_static_ziploc_slide_raw", + "lerobot-raw/pusht_raw", + "lerobot-raw/umi_cup_in_the_wild_raw", + "lerobot-raw/unitreeh1_fold_clothes_raw", + "lerobot-raw/unitreeh1_rearrange_objects_raw", + "lerobot-raw/unitreeh1_two_robot_greeting_raw", + "lerobot-raw/unitreeh1_warehouse_raw", + "lerobot-raw/xarm_lift_medium_raw", + "lerobot-raw/xarm_lift_medium_replay_raw", + "lerobot-raw/xarm_push_medium_raw", + "lerobot-raw/xarm_push_medium_replay_raw", ] @@ -89,7 +84,6 @@ def download_raw(raw_dir: Path, repo_id: str): stacklevel=1, ) - raw_dir = Path(raw_dir) # Send warning if raw_dir isn't well formated if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id: warnings.warn( @@ -99,7 +93,7 @@ def download_raw(raw_dir: Path, repo_id: str): raw_dir.mkdir(parents=True, exist_ok=True) logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}") - snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir) + snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir) logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}") diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py index 024045a0..24873ca2 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py @@ -28,6 +28,7 @@ import tqdm from datasets import Dataset, Features, Image, Sequence, Value from PIL import Image as PILImage +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( calculate_episode_data_index, @@ -210,6 +211,7 @@ def from_raw_to_lerobot_format( hf_dataset = to_hf_dataset(data_dict, video) episode_data_index = calculate_episode_data_index(hf_dataset) info = { + "codebase_version": CODEBASE_VERSION, "fps": fps, "video": video, } diff --git a/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py b/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py index 4972e6b4..52eabd99 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py @@ -23,6 +23,7 @@ import torch from datasets import Dataset, Features, Image, Value from PIL import Image as PILImage +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch from lerobot.common.datasets.video_utils import VideoFrame @@ -95,6 +96,7 @@ def from_raw_to_lerobot_format( hf_dataset = to_hf_dataset(data_dict, video) episode_data_index = calculate_episode_data_index(hf_dataset) info = { + "codebase_version": CODEBASE_VERSION, "fps": fps, "video": video, } diff --git a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py index 1dc2e67e..832f3af2 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py @@ -24,6 +24,7 @@ import pandas as pd import torch from datasets import Dataset, Features, Image, Sequence, Value +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.utils import ( calculate_episode_data_index, hf_transform_to_torch, @@ -214,6 +215,7 @@ def from_raw_to_lerobot_format( hf_dataset = to_hf_dataset(data_df, video) episode_data_index = calculate_episode_data_index(hf_dataset) info = { + "codebase_version": CODEBASE_VERSION, "fps": fps, "video": video, } diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py index 69b23a47..54043eee 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py @@ -25,6 +25,7 @@ import zarr from datasets import Dataset, Features, Image, Sequence, Value from PIL import Image as PILImage +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( calculate_episode_data_index, @@ -258,6 +259,7 @@ def from_raw_to_lerobot_format( hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image) episode_data_index = calculate_episode_data_index(hf_dataset) info = { + "codebase_version": CODEBASE_VERSION, "fps": fps, "video": video if not keypoints_instead_of_image else 0, } diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py index 6cd80c61..f9ac849c 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py @@ -25,6 +25,7 @@ import zarr from datasets import Dataset, Features, Image, Sequence, Value from PIL import Image as PILImage +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( @@ -199,6 +200,7 @@ def from_raw_to_lerobot_format( hf_dataset = to_hf_dataset(data_dict, video) episode_data_index = calculate_episode_data_index(hf_dataset) info = { + "codebase_version": CODEBASE_VERSION, "fps": fps, "video": video, } diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py index 57a36dba..d6ffbea1 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py @@ -25,6 +25,7 @@ import tqdm from datasets import Dataset, Features, Image, Sequence, Value from PIL import Image as PILImage +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( calculate_episode_data_index, @@ -177,6 +178,7 @@ def from_raw_to_lerobot_format( hf_dataset = to_hf_dataset(data_dict, video) episode_data_index = calculate_episode_data_index(hf_dataset) info = { + "codebase_version": CODEBASE_VERSION, "fps": fps, "video": video, } diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index cb2fee95..af1a3db6 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -15,13 +15,15 @@ # limitations under the License. import json import re +import warnings +from functools import cache from pathlib import Path from typing import Dict import datasets import torch from datasets import load_dataset, load_from_disk -from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub import HfApi, hf_hub_download, snapshot_download from PIL import Image as PILImage from safetensors.torch import load_file from torchvision import transforms @@ -80,7 +82,28 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): return items_dict -def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: +@cache +def get_hf_dataset_safe_version(repo_id: str, version: str) -> str: + api = HfApi() + dataset_info = api.list_repo_refs(repo_id, repo_type="dataset") + branches = [b.name for b in dataset_info.branches] + if version not in branches: + warnings.warn( + f"""You are trying to load a dataset from {repo_id} created with a previous version of the + codebase. The following versions are available: {branches}. + The requested version ('{version}') is not found. You should be fine since + backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on + Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""", + stacklevel=1, + ) + if "main" not in branches: + raise ValueError(f"Version 'main' not found on {repo_id}") + return "main" + else: + return version + + +def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" if root is not None: hf_dataset = load_from_disk(str(Path(root) / repo_id / "train")) @@ -101,7 +124,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"' ) else: - hf_dataset = load_dataset(repo_id, revision=version, split=split) + safe_version = get_hf_dataset_safe_version(repo_id, version) + hf_dataset = load_dataset(repo_id, revision=safe_version, split=split) + hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset @@ -119,8 +144,9 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]: if root is not None: path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors" else: + safe_version = get_hf_dataset_safe_version(repo_id, version) path = hf_hub_download( - repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version + repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version ) return load_file(path) @@ -137,7 +163,10 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]: if root is not None: path = Path(root) / repo_id / "meta_data" / "stats.safetensors" else: - path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version) + safe_version = get_hf_dataset_safe_version(repo_id, version) + path = hf_hub_download( + repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version + ) stats = load_file(path) return unflatten_dict(stats) @@ -154,7 +183,8 @@ def load_info(repo_id, version, root) -> dict: if root is not None: path = Path(root) / repo_id / "meta_data" / "info.json" else: - path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version) + safe_version = get_hf_dataset_safe_version(repo_id, version) + path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version) with open(path) as f: info = json.load(f) @@ -166,7 +196,8 @@ def load_videos(repo_id, version, root) -> Path: path = Path(root) / repo_id / "videos" else: # TODO(rcadene): we download the whole repo here. see if we can avoid this - repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version) + safe_version = get_hf_dataset_safe_version(repo_id, version) + repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version) path = Path(repo_dir) / "videos" return path diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 92a52eac..a86c359c 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -76,12 +76,10 @@ class ACTConfig: documentation in the policy class). latent_dim: The VAE's latent dimension. n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder. - temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling - actions for a given time step over multiple policy invocations. Updates are calculated as: - x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different - parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our - formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we - require `n_action_steps == 1` (since we need to query the policy every step anyway). + temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal + ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be + 1 when using this feature, as inference needs to happen at every step to form an ensemble. For + more information on how ensembling works, please see `ACTTemporalEnsembler`. dropout: Dropout to use in the transformer layers (see code for details). kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. @@ -139,7 +137,8 @@ class ACTConfig: n_vae_encoder_layers: int = 4 # Inference. - temporal_ensemble_momentum: float | None = None + # Note: the value used in ACT when temporal ensembling is enabled is 0.01. + temporal_ensemble_coeff: float | None = None # Training and loss computation. dropout: float = 0.1 @@ -151,7 +150,7 @@ class ACTConfig: raise ValueError( f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) - if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1: + if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1: raise NotImplementedError( "`n_action_steps` must be 1 when using temporal ensembling. This is " "because the policy needs to be queried every step to compute the ensembled action." diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 0a236100..c072c31e 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -77,12 +77,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + if config.temporal_ensemble_coeff is not None: + self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) + self.reset() def reset(self): """This should be called whenever the environment is reset.""" - if self.config.temporal_ensemble_momentum is not None: - self._ensembled_actions = None + if self.config.temporal_ensemble_coeff is not None: + self.temporal_ensembler.reset() else: self._action_queue = deque([], maxlen=self.config.n_action_steps) @@ -100,24 +103,12 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): if len(self.expected_image_keys) > 0: batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) - # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return - # the first action. - if self.config.temporal_ensemble_momentum is not None: + # If we are doing temporal ensembling, do online updates where we keep track of the number of actions + # we are ensembling over. + if self.config.temporal_ensemble_coeff is not None: actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) actions = self.unnormalize_outputs({"action": actions})["action"] - if self._ensembled_actions is None: - # Initializes `self._ensembled_action` to the sequence of actions predicted during the first - # time step of the episode. - self._ensembled_actions = actions.clone() - else: - # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute - # the EMA update for those entries. - alpha = self.config.temporal_ensemble_momentum - self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1] - # The last action, which has no prior moving average, needs to get concatenated onto the end. - self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1) - # "Consume" the first action. - action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:] + action = self.temporal_ensembler.update(actions) return action # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by @@ -162,6 +153,97 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): return loss_dict +class ACTTemporalEnsembler: + def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: + """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705. + + The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action. + They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the + coefficient works: + - Setting it to 0 uniformly weighs all actions. + - Setting it positive gives more weight to older actions. + - Setting it negative gives more weight to newer actions. + NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This + results in older actions being weighed more highly than newer actions (the experiments documented in + https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be + detrimental: doing so aggressively may diminish the benefits of action chunking). + + Here we use an online method for computing the average rather than caching a history of actions in + order to compute the average offline. For a simple 1D sequence it looks something like: + + ``` + import torch + + seq = torch.linspace(8, 8.5, 100) + print(seq) + + m = 0.01 + exp_weights = torch.exp(-m * torch.arange(len(seq))) + print(exp_weights) + + # Calculate offline + avg = (exp_weights * seq).sum() / exp_weights.sum() + print("offline", avg) + + # Calculate online + for i, item in enumerate(seq): + if i == 0: + avg = item + continue + avg *= exp_weights[:i].sum() + avg += item * exp_weights[i] + avg /= exp_weights[:i+1].sum() + print("online", avg) + ``` + """ + self.chunk_size = chunk_size + self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) + self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) + self.reset() + + def reset(self): + """Resets the online computation variables.""" + self.ensembled_actions = None + # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence. + self.ensembled_actions_count = None + + def update(self, actions: Tensor) -> Tensor: + """ + Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all + time steps, and pop/return the next batch of actions in the sequence. + """ + self.ensemble_weights = self.ensemble_weights.to(device=actions.device) + self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) + if self.ensembled_actions is None: + # Initializes `self._ensembled_action` to the sequence of actions predicted during the first + # time step of the episode. + self.ensembled_actions = actions.clone() + # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor + # operations later. + self.ensembled_actions_count = torch.ones( + (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device + ) + else: + # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute + # the online update for those entries. + self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] + self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] + self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] + self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) + # The last action, which has no prior online average, needs to get concatenated onto the end. + self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) + self.ensembled_actions_count = torch.cat( + [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])] + ) + # "Consume" the first action. + action, self.ensembled_actions, self.ensembled_actions_count = ( + self.ensembled_actions[:, 0], + self.ensembled_actions[:, 1:], + self.ensembled_actions_count[1:], + ) + return action + + class ACT(nn.Module): """Action Chunking Transformer: The underlying neural network for ACTPolicy. diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index ea2c5b75..28883936 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -75,7 +75,7 @@ policy: n_vae_encoder_layers: 4 # Inference. - temporal_ensemble_momentum: null + temporal_ensemble_coeff: null # Training and loss computation. dropout: 0.1 diff --git a/lerobot/configs/policy/act_real.yaml b/lerobot/configs/policy/act_real.yaml index c2f7158f..058104f4 100644 --- a/lerobot/configs/policy/act_real.yaml +++ b/lerobot/configs/policy/act_real.yaml @@ -107,7 +107,7 @@ policy: n_vae_encoder_layers: 4 # Inference. - temporal_ensemble_momentum: null + temporal_ensemble_coeff: null # Training and loss computation. dropout: 0.1 diff --git a/lerobot/configs/policy/act_real_no_state.yaml b/lerobot/configs/policy/act_real_no_state.yaml index 5b8a13b4..08261050 100644 --- a/lerobot/configs/policy/act_real_no_state.yaml +++ b/lerobot/configs/policy/act_real_no_state.yaml @@ -103,7 +103,7 @@ policy: n_vae_encoder_layers: 4 # Inference. - temporal_ensemble_momentum: null + temporal_ensemble_coeff: null # Training and loss computation. dropout: 0.1 diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index d7d3b25e..e7c83880 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -494,6 +494,7 @@ def record_dataset( hf_dataset = to_hf_dataset(data_dict, video) episode_data_index = calculate_episode_data_index(hf_dataset) info = { + "codebase_version": CODEBASE_VERSION, "fps": fps, "video": video, } diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index fe62e7c1..ce1a06f7 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -40,60 +40,6 @@ python lerobot/scripts/push_dataset_to_hub.py \ --raw-format umi_zarr \ --repo-id lerobot/umi_cup_in_the_wild ``` - -**WARNING: Updating an existing dataset** - -If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py` -before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change -intentionally or not (i.e. something not backward compatible such as modifying the reward functions used, -deleting some frames at the end of an episode, etc.). That way, people running a previous version of the -codebase won't be affected by your change and backward compatibility is maintained. - -For instance, Pusht has many versions to maintain backward compatibility between LeRobot codebase versions: -- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0) -- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1) -- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2) -- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3) -- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4) -- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version -- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version - -However, you will need to update the version of ALL the other datasets so that they have the new -`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way -that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF -dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed): - -```python -import os - -from huggingface_hub import create_branch, hf_hub_download -from huggingface_hub.utils._errors import RepositoryNotFoundError - -from lerobot import available_datasets -from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION - -os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # makes it easier to see the print-out below - -NEW_CODEBASE_VERSION = "v1.5" # REPLACE THIS WITH YOUR DESIRED VERSION - -for repo_id in available_datasets: - # First check if the newer version already exists. - try: - hf_hub_download( - repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION - ) - print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.") - print("Exiting early") - break - except RepositoryNotFoundError: - # Now create a branch. - create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION) - print(f"{repo_id} successfully updated") - -``` - -On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions -above, nor to be compatible with previous codebase versions. """ import argparse @@ -104,7 +50,7 @@ from pathlib import Path from typing import Any import torch -from huggingface_hub import HfApi, create_branch +from huggingface_hub import HfApi from safetensors.torch import save_file from lerobot.common.datasets.compute_stats import compute_stats @@ -270,7 +216,8 @@ def push_dataset_to_hub( push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") if video: push_videos_to_hub(repo_id, videos_dir, revision="main") - create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION) + api = HfApi() + api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION) if tests_data_dir: # get the first episode diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 4e636db8..f707fe12 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -272,7 +272,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No cfg.resume = True elif Logger.get_last_checkpoint_dir(out_dir).exists(): raise RuntimeError( - f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists." + f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If " + "you meant to resume training, please use `resume=true` in your command or yaml configuration." ) # log metrics to terminal and wandb diff --git a/tests/test_policies.py b/tests/test_policies.py index bc9c34ff..63f394e9 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -16,6 +16,7 @@ import inspect from pathlib import Path +import einops import pytest import torch from huggingface_hub import PyTorchModelHubMixin @@ -26,6 +27,7 @@ from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import preprocess_observation +from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler from lerobot.common.policies.factory import ( _policy_cfg_from_hydra_cfg, get_policy_and_config_classes, @@ -33,7 +35,7 @@ from lerobot.common.policies.factory import ( ) from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy -from lerobot.common.utils.utils import init_hydra_config +from lerobot.common.utils.utils import init_hydra_config, seeded_context from lerobot.scripts.train import make_optimizer_and_scheduler from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -390,3 +392,62 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all() for key in saved_actions: assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all() + + +def test_act_temporal_ensembler(): + """Check that the online method in ACTTemporalEnsembler matches a simple offline calculation.""" + temporal_ensemble_coeff = 0.01 + chunk_size = 100 + episode_length = 101 + ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size) + # An batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the + # "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen. + with seeded_context(0): + # Dimension is (batch, episode_length, chunk_size, action_dim(=1)) + # Stepping through the episode_length dim is like running inference at each rollout step and getting + # a different action chunk. + batch_seq = torch.stack( + [ + torch.rand(episode_length, chunk_size) * 0.05 - 0.6, + torch.rand(episode_length, chunk_size) * 0.02 - 0.01, + torch.rand(episode_length, chunk_size) * 0.2 + 0.3, + ], + dim=0, + ).unsqueeze(-1) # unsqueeze for action dim + batch_size = batch_seq.shape[0] + # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length` + # dimension of `batch_seq`. + weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1) + + # Simulate stepping through a rollout and computing a batch of actions with model on each step. + for i in range(episode_length): + # Mock a batch of actions. + actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i] + online_avg = ensembler.update(actions) + # Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ). + # Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid. + # What we want to do is take diagonal slices across it starting from the left. + # eg: chunk_size=4, episode_length=6 + # ┌───────┐ + # │0 1 2 3│ + # │1 2 3 4│ + # │2 3 4 5│ + # │3 4 5 6│ + # │4 5 6 7│ + # │5 6 7 8│ + # └───────┘ + chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1) + episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :] + seq_slice = batch_seq[:, episode_step_indices, chunk_indices] + offline_avg = ( + einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum() + ) + # Sanity check. The average should be between the extrema. + assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg) + assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max")) + # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error. + assert torch.allclose(online_avg, offline_avg, atol=1e-4) + + +if __name__ == "__main__": + test_act_temporal_ensembler()