Merge remote-tracking branch 'origin/main' into user/rcadene/2024_07_16_control_robot_v2

This commit is contained in:
Remi Cadene 2024-07-17 23:04:14 +02:00
commit 895182b272
20 changed files with 331 additions and 152 deletions

View File

@ -80,7 +80,7 @@ policy:
n_vae_encoder_layers: 4 n_vae_encoder_layers: 4
# Inference. # Inference.
temporal_ensemble_momentum: null temporal_ensemble_coeff: null
# Training and loss computation. # Training and loss computation.
dropout: 0.1 dropout: 0.1

View File

@ -35,15 +35,16 @@ from lerobot.common.datasets.utils import (
) )
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/codebase_version.md
CODEBASE_VERSION = "v1.5" CODEBASE_VERSION = "v1.5"
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
class LeRobotDataset(torch.utils.data.Dataset): class LeRobotDataset(torch.utils.data.Dataset):
def __init__( def __init__(
self, self,
repo_id: str, repo_id: str,
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR, root: Path | None = DATA_DIR,
split: str = "train", split: str = "train",
image_transforms: Callable | None = None, image_transforms: Callable | None = None,
@ -52,7 +53,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
): ):
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
self.version = version
self.root = root self.root = root
self.split = split self.split = split
self.image_transforms = image_transforms self.image_transforms = image_transforms
@ -60,16 +60,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
# load data from hub or locally when root is provided # load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer # TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
self.hf_dataset = load_hf_dataset(repo_id, version, root, split) self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
if split == "train": if split == "train":
self.episode_data_index = load_episode_data_index(repo_id, version, root) self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
else: else:
self.episode_data_index = calculate_episode_data_index(self.hf_dataset) self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
self.hf_dataset = reset_episode_index(self.hf_dataset) self.hf_dataset = reset_episode_index(self.hf_dataset)
self.stats = load_stats(repo_id, version, root) self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
self.info = load_info(repo_id, version, root) self.info = load_info(repo_id, CODEBASE_VERSION, root)
if self.video: if self.video:
self.videos_dir = load_videos(repo_id, version, root) self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
self.video_backend = video_backend if video_backend is not None else "pyav" self.video_backend = video_backend if video_backend is not None else "pyav"
@property @property
@ -164,7 +164,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return ( return (
f"{self.__class__.__name__}(\n" f"{self.__class__.__name__}(\n"
f" Repository ID: '{self.repo_id}',\n" f" Repository ID: '{self.repo_id}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n" f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n" f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n" f" Number of Episodes: {self.num_episodes},\n"
@ -173,6 +172,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Camera Keys: {self.camera_keys},\n" f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n" f" Transformations: {self.image_transforms},\n"
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
f")" f")"
) )
@ -180,7 +180,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
def from_preloaded( def from_preloaded(
cls, cls,
repo_id: str = "from_preloaded", repo_id: str = "from_preloaded",
version: str | None = CODEBASE_VERSION,
root: Path | None = None, root: Path | None = None,
split: str = "train", split: str = "train",
transform: callable = None, transform: callable = None,
@ -204,7 +203,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# create an empty object of type LeRobotDataset # create an empty object of type LeRobotDataset
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.repo_id = repo_id obj.repo_id = repo_id
obj.version = version
obj.root = root obj.root = root
obj.split = split obj.split = split
obj.image_transforms = transform obj.image_transforms = transform
@ -228,7 +226,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __init__( def __init__(
self, self,
repo_ids: list[str], repo_ids: list[str],
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR, root: Path | None = DATA_DIR,
split: str = "train", split: str = "train",
image_transforms: Callable | None = None, image_transforms: Callable | None = None,
@ -242,7 +239,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self._datasets = [ self._datasets = [
LeRobotDataset( LeRobotDataset(
repo_id, repo_id,
version=version,
root=root, root=root,
split=split, split=split,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
@ -279,7 +275,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
) )
self.disabled_data_keys.update(extra_keys) self.disabled_data_keys.update(extra_keys)
self.version = version
self.root = root self.root = root
self.split = split self.split = split
self.image_transforms = image_transforms self.image_transforms = image_transforms
@ -395,7 +390,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
return ( return (
f"{self.__class__.__name__}(\n" f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n" f" Repository IDs: '{self.repo_ids}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n" f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n" f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n" f" Number of Episodes: {self.num_episodes},\n"

View File

@ -0,0 +1,57 @@
## Using / Updating `CODEBASE_VERSION` (for maintainers)
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
lerobot/common/datasets/lerobot_dataset.py) variable.
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
`info.json` metadata.
### Uploading a new dataset
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
dataset with the current `CODEBASE_VERSION`.
### Updating an existing dataset
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change and backward compatibility is maintained.
However, you will need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
```python
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
api = HfApi()
for repo_id in available_datasets:
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if CODEBASE_VERSION in branches:
# First check if the newer version already exists.
print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
print("Exiting early")
break
else:
# Now create a branch named after the new version by branching out from "main"
# which is expected to be the preceding version
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
print(f"{repo_id} successfully updated")
```

View File

@ -32,46 +32,41 @@ from pathlib import Path
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
AVAILABLE_RAW_REPO_IDS = [ AVAILABLE_RAW_REPO_IDS = [
"cadene/pusht_image_raw", "lerobot-raw/aloha_mobile_cabinet_raw",
"cadene/xarm_lift_medium_image_raw", "lerobot-raw/aloha_mobile_chair_raw",
"cadene/xarm_lift_medium_replay_image_raw", "lerobot-raw/aloha_mobile_elevator_raw",
"cadene/xarm_push_medium_image_raw", "lerobot-raw/aloha_mobile_shrimp_raw",
"cadene/xarm_push_medium_replay_image_raw", "lerobot-raw/aloha_mobile_wash_pan_raw",
"cadene/aloha_sim_insertion_human_image_raw", "lerobot-raw/aloha_mobile_wipe_wine_raw",
"cadene/aloha_sim_insertion_scripted_image_raw", "lerobot-raw/aloha_sim_insertion_human_raw",
"cadene/aloha_sim_transfer_cube_human_image_raw", "lerobot-raw/aloha_sim_insertion_scripted_raw",
"cadene/aloha_sim_transfer_cube_scripted_image_raw", "lerobot-raw/aloha_sim_transfer_cube_human_raw",
"cadene/pusht_raw", "lerobot-raw/aloha_sim_transfer_cube_scripted_raw",
"cadene/xarm_lift_medium_raw", "lerobot-raw/aloha_static_battery_raw",
"cadene/xarm_lift_medium_replay_raw", "lerobot-raw/aloha_static_candy_raw",
"cadene/xarm_push_medium_raw", "lerobot-raw/aloha_static_coffee_new_raw",
"cadene/xarm_push_medium_replay_raw", "lerobot-raw/aloha_static_coffee_raw",
"cadene/aloha_sim_insertion_human_raw", "lerobot-raw/aloha_static_cups_open_raw",
"cadene/aloha_sim_insertion_scripted_raw", "lerobot-raw/aloha_static_fork_pick_up_raw",
"cadene/aloha_sim_transfer_cube_human_raw", "lerobot-raw/aloha_static_pingpong_test_raw",
"cadene/aloha_sim_transfer_cube_scripted_raw", "lerobot-raw/aloha_static_pro_pencil_raw",
"cadene/aloha_mobile_cabinet_raw", "lerobot-raw/aloha_static_screw_driver_raw",
"cadene/aloha_mobile_chair_raw", "lerobot-raw/aloha_static_tape_raw",
"cadene/aloha_mobile_elevator_raw", "lerobot-raw/aloha_static_thread_velcro_raw",
"cadene/aloha_mobile_shrimp_raw", "lerobot-raw/aloha_static_towel_raw",
"cadene/aloha_mobile_wash_pan_raw", "lerobot-raw/aloha_static_vinh_cup_left_raw",
"cadene/aloha_mobile_wipe_wine_raw", "lerobot-raw/aloha_static_vinh_cup_raw",
"cadene/aloha_static_battery_raw", "lerobot-raw/aloha_static_ziploc_slide_raw",
"cadene/aloha_static_candy_raw", "lerobot-raw/pusht_raw",
"cadene/aloha_static_coffee_raw", "lerobot-raw/umi_cup_in_the_wild_raw",
"cadene/aloha_static_coffee_new_raw", "lerobot-raw/unitreeh1_fold_clothes_raw",
"cadene/aloha_static_cups_open_raw", "lerobot-raw/unitreeh1_rearrange_objects_raw",
"cadene/aloha_static_fork_pick_up_raw", "lerobot-raw/unitreeh1_two_robot_greeting_raw",
"cadene/aloha_static_pingpong_test_raw", "lerobot-raw/unitreeh1_warehouse_raw",
"cadene/aloha_static_pro_pencil_raw", "lerobot-raw/xarm_lift_medium_raw",
"cadene/aloha_static_screw_driver_raw", "lerobot-raw/xarm_lift_medium_replay_raw",
"cadene/aloha_static_tape_raw", "lerobot-raw/xarm_push_medium_raw",
"cadene/aloha_static_thread_velcro_raw", "lerobot-raw/xarm_push_medium_replay_raw",
"cadene/aloha_static_towel_raw",
"cadene/aloha_static_vinh_cup_raw",
"cadene/aloha_static_vinh_cup_left_raw",
"cadene/aloha_static_ziploc_slide_raw",
"cadene/umi_cup_in_the_wild_raw",
] ]
@ -89,7 +84,6 @@ def download_raw(raw_dir: Path, repo_id: str):
stacklevel=1, stacklevel=1,
) )
raw_dir = Path(raw_dir)
# Send warning if raw_dir isn't well formated # Send warning if raw_dir isn't well formated
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id: if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
warnings.warn( warnings.warn(
@ -99,7 +93,7 @@ def download_raw(raw_dir: Path, repo_id: str):
raw_dir.mkdir(parents=True, exist_ok=True) raw_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}") logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir) snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}") logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")

View File

@ -28,6 +28,7 @@ import tqdm
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index, calculate_episode_data_index,
@ -210,6 +211,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -23,6 +23,7 @@ import torch
from datasets import Dataset, Features, Image, Value from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
from lerobot.common.datasets.video_utils import VideoFrame from lerobot.common.datasets.video_utils import VideoFrame
@ -95,6 +96,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -24,6 +24,7 @@ import pandas as pd
import torch import torch
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index, calculate_episode_data_index,
hf_transform_to_torch, hf_transform_to_torch,
@ -214,6 +215,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_df, video) hf_dataset = to_hf_dataset(data_df, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -25,6 +25,7 @@ import zarr
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index, calculate_episode_data_index,
@ -258,6 +259,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image) hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video if not keypoints_instead_of_image else 0, "video": video if not keypoints_instead_of_image else 0,
} }

View File

@ -25,6 +25,7 @@ import zarr
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
@ -199,6 +200,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -25,6 +25,7 @@ import tqdm
from datasets import Dataset, Features, Image, Sequence, Value from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
calculate_episode_data_index, calculate_episode_data_index,
@ -177,6 +178,7 @@ def from_raw_to_lerobot_format(
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -15,13 +15,15 @@
# limitations under the License. # limitations under the License.
import json import json
import re import re
import warnings
from functools import cache
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
import datasets import datasets
import torch import torch
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
from huggingface_hub import hf_hub_download, snapshot_download from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from PIL import Image as PILImage from PIL import Image as PILImage
from safetensors.torch import load_file from safetensors.torch import load_file
from torchvision import transforms from torchvision import transforms
@ -80,7 +82,28 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict return items_dict
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: @cache
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
warnings.warn(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
The requested version ('{version}') is not found. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
stacklevel=1,
)
if "main" not in branches:
raise ValueError(f"Version 'main' not found on {repo_id}")
return "main"
else:
return version
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc.""" """hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None: if root is not None:
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train")) hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
@ -101,7 +124,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"' f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
) )
else: else:
hf_dataset = load_dataset(repo_id, revision=version, split=split) safe_version = get_hf_dataset_safe_version(repo_id, version)
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
hf_dataset.set_transform(hf_transform_to_torch) hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset return hf_dataset
@ -119,8 +144,9 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
if root is not None: if root is not None:
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors" path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else: else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download( path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
) )
return load_file(path) return load_file(path)
@ -137,7 +163,10 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
if root is not None: if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors" path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else: else:
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version) safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
)
stats = load_file(path) stats = load_file(path)
return unflatten_dict(stats) return unflatten_dict(stats)
@ -154,7 +183,8 @@ def load_info(repo_id, version, root) -> dict:
if root is not None: if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json" path = Path(root) / repo_id / "meta_data" / "info.json"
else: else:
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version) safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
with open(path) as f: with open(path) as f:
info = json.load(f) info = json.load(f)
@ -166,7 +196,8 @@ def load_videos(repo_id, version, root) -> Path:
path = Path(root) / repo_id / "videos" path = Path(root) / repo_id / "videos"
else: else:
# TODO(rcadene): we download the whole repo here. see if we can avoid this # TODO(rcadene): we download the whole repo here. see if we can avoid this
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version) safe_version = get_hf_dataset_safe_version(repo_id, version)
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
path = Path(repo_dir) / "videos" path = Path(repo_dir) / "videos"
return path return path

View File

@ -76,12 +76,10 @@ class ACTConfig:
documentation in the policy class). documentation in the policy class).
latent_dim: The VAE's latent dimension. latent_dim: The VAE's latent dimension.
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder. n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
actions for a given time step over multiple policy invocations. Updates are calculated as: ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
x = αx + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different 1 when using this feature, as inference needs to happen at every step to form an ensemble. For
parameter here: they refer to a weighting scheme wᵢ = exp(-mi) and set m = 0.01. With our more information on how ensembling works, please see `ACTTemporalEnsembler`.
formulation, this is equivalent to α = exp(-0.01) 0.99. When this parameter is provided, we
require `n_action_steps == 1` (since we need to query the policy every step anyway).
dropout: Dropout to use in the transformer layers (see code for details). dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@ -139,7 +137,8 @@ class ACTConfig:
n_vae_encoder_layers: int = 4 n_vae_encoder_layers: int = 4
# Inference. # Inference.
temporal_ensemble_momentum: float | None = None # Note: the value used in ACT when temporal ensembling is enabled is 0.01.
temporal_ensemble_coeff: float | None = None
# Training and loss computation. # Training and loss computation.
dropout: float = 0.1 dropout: float = 0.1
@ -151,7 +150,7 @@ class ACTConfig:
raise ValueError( raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
) )
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1: if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
raise NotImplementedError( raise NotImplementedError(
"`n_action_steps` must be 1 when using temporal ensembling. This is " "`n_action_steps` must be 1 when using temporal ensembling. This is "
"because the policy needs to be queried every step to compute the ensembled action." "because the policy needs to be queried every step to compute the ensembled action."

View File

@ -77,12 +77,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset() self.reset()
def reset(self): def reset(self):
"""This should be called whenever the environment is reset.""" """This should be called whenever the environment is reset."""
if self.config.temporal_ensemble_momentum is not None: if self.config.temporal_ensemble_coeff is not None:
self._ensembled_actions = None self.temporal_ensembler.reset()
else: else:
self._action_queue = deque([], maxlen=self.config.n_action_steps) self._action_queue = deque([], maxlen=self.config.n_action_steps)
@ -100,24 +103,12 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
if len(self.expected_image_keys) > 0: if len(self.expected_image_keys) > 0:
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# the first action. # we are ensembling over.
if self.config.temporal_ensemble_momentum is not None: if self.config.temporal_ensemble_coeff is not None:
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
if self._ensembled_actions is None: action = self.temporal_ensembler.update(actions)
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self._ensembled_actions = actions.clone()
else:
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the EMA update for those entries.
alpha = self.config.temporal_ensemble_momentum
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
# The last action, which has no prior moving average, needs to get concatenated onto the end.
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
# "Consume" the first action.
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
return action return action
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
@ -162,6 +153,97 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
return loss_dict return loss_dict
class ACTTemporalEnsembler:
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
"""Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
coefficient works:
- Setting it to 0 uniformly weighs all actions.
- Setting it positive gives more weight to older actions.
- Setting it negative gives more weight to newer actions.
NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
results in older actions being weighed more highly than newer actions (the experiments documented in
https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
detrimental: doing so aggressively may diminish the benefits of action chunking).
Here we use an online method for computing the average rather than caching a history of actions in
order to compute the average offline. For a simple 1D sequence it looks something like:
```
import torch
seq = torch.linspace(8, 8.5, 100)
print(seq)
m = 0.01
exp_weights = torch.exp(-m * torch.arange(len(seq)))
print(exp_weights)
# Calculate offline
avg = (exp_weights * seq).sum() / exp_weights.sum()
print("offline", avg)
# Calculate online
for i, item in enumerate(seq):
if i == 0:
avg = item
continue
avg *= exp_weights[:i].sum()
avg += item * exp_weights[i]
avg /= exp_weights[:i+1].sum()
print("online", avg)
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
def reset(self):
"""Resets the online computation variables."""
self.ensembled_actions = None
# (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
self.ensembled_actions_count = None
def update(self, actions: Tensor) -> Tensor:
"""
Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self.ensembled_actions = actions.clone()
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
self.ensembled_actions[:, 0],
self.ensembled_actions[:, 1:],
self.ensembled_actions_count[1:],
)
return action
class ACT(nn.Module): class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy. """Action Chunking Transformer: The underlying neural network for ACTPolicy.

View File

@ -75,7 +75,7 @@ policy:
n_vae_encoder_layers: 4 n_vae_encoder_layers: 4
# Inference. # Inference.
temporal_ensemble_momentum: null temporal_ensemble_coeff: null
# Training and loss computation. # Training and loss computation.
dropout: 0.1 dropout: 0.1

View File

@ -107,7 +107,7 @@ policy:
n_vae_encoder_layers: 4 n_vae_encoder_layers: 4
# Inference. # Inference.
temporal_ensemble_momentum: null temporal_ensemble_coeff: null
# Training and loss computation. # Training and loss computation.
dropout: 0.1 dropout: 0.1

View File

@ -103,7 +103,7 @@ policy:
n_vae_encoder_layers: 4 n_vae_encoder_layers: 4
# Inference. # Inference.
temporal_ensemble_momentum: null temporal_ensemble_coeff: null
# Training and loss computation. # Training and loss computation.
dropout: 0.1 dropout: 0.1

View File

@ -494,6 +494,7 @@ def record_dataset(
hf_dataset = to_hf_dataset(data_dict, video) hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset) episode_data_index = calculate_episode_data_index(hf_dataset)
info = { info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps, "fps": fps,
"video": video, "video": video,
} }

View File

@ -40,60 +40,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
--raw-format umi_zarr \ --raw-format umi_zarr \
--repo-id lerobot/umi_cup_in_the_wild --repo-id lerobot/umi_cup_in_the_wild
``` ```
**WARNING: Updating an existing dataset**
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change and backward compatibility is maintained.
For instance, Pusht has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
However, you will need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
```python
import os
from huggingface_hub import create_branch, hf_hub_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # makes it easier to see the print-out below
NEW_CODEBASE_VERSION = "v1.5" # REPLACE THIS WITH YOUR DESIRED VERSION
for repo_id in available_datasets:
# First check if the newer version already exists.
try:
hf_hub_download(
repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
)
print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
print("Exiting early")
break
except RepositoryNotFoundError:
# Now create a branch.
create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
print(f"{repo_id} successfully updated")
```
On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
above, nor to be compatible with previous codebase versions.
""" """
import argparse import argparse
@ -104,7 +50,7 @@ from pathlib import Path
from typing import Any from typing import Any
import torch import torch
from huggingface_hub import HfApi, create_branch from huggingface_hub import HfApi
from safetensors.torch import save_file from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.compute_stats import compute_stats
@ -270,7 +216,8 @@ def push_dataset_to_hub(
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
if video: if video:
push_videos_to_hub(repo_id, videos_dir, revision="main") push_videos_to_hub(repo_id, videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION) api = HfApi()
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
if tests_data_dir: if tests_data_dir:
# get the first episode # get the first episode

View File

@ -272,7 +272,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
cfg.resume = True cfg.resume = True
elif Logger.get_last_checkpoint_dir(out_dir).exists(): elif Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError( raise RuntimeError(
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists." f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
) )
# log metrics to terminal and wandb # log metrics to terminal and wandb

View File

@ -16,6 +16,7 @@
import inspect import inspect
from pathlib import Path from pathlib import Path
import einops
import pytest import pytest
import torch import torch
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
@ -26,6 +27,7 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.common.policies.factory import ( from lerobot.common.policies.factory import (
_policy_cfg_from_hydra_cfg, _policy_cfg_from_hydra_cfg,
get_policy_and_config_classes, get_policy_and_config_classes,
@ -33,7 +35,7 @@ from lerobot.common.policies.factory import (
) )
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.train import make_optimizer_and_scheduler from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@ -390,3 +392,62 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all() assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
for key in saved_actions: for key in saved_actions:
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all() assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
def test_act_temporal_ensembler():
"""Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
temporal_ensemble_coeff = 0.01
chunk_size = 100
episode_length = 101
ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
# An batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
# "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
with seeded_context(0):
# Dimension is (batch, episode_length, chunk_size, action_dim(=1))
# Stepping through the episode_length dim is like running inference at each rollout step and getting
# a different action chunk.
batch_seq = torch.stack(
[
torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
],
dim=0,
).unsqueeze(-1) # unsqueeze for action dim
batch_size = batch_seq.shape[0]
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
# dimension of `batch_seq`.
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
for i in range(episode_length):
# Mock a batch of actions.
actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
online_avg = ensembler.update(actions)
# Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
# Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
# What we want to do is take diagonal slices across it starting from the left.
# eg: chunk_size=4, episode_length=6
# ┌───────┐
# │0 1 2 3│
# │1 2 3 4│
# │2 3 4 5│
# │3 4 5 6│
# │4 5 6 7│
# │5 6 7 8│
# └───────┘
chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
offline_avg = (
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
)
# Sanity check. The average should be between the extrema.
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
if __name__ == "__main__":
test_act_temporal_ensembler()