Merge remote-tracking branch 'origin/main' into user/rcadene/2024_07_16_control_robot_v2
This commit is contained in:
commit
895182b272
|
@ -80,7 +80,7 @@ policy:
|
||||||
n_vae_encoder_layers: 4
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: null
|
temporal_ensemble_coeff: null
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
|
|
|
@ -35,15 +35,16 @@ from lerobot.common.datasets.utils import (
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
||||||
|
|
||||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/codebase_version.md
|
||||||
CODEBASE_VERSION = "v1.5"
|
CODEBASE_VERSION = "v1.5"
|
||||||
|
|
||||||
|
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDataset(torch.utils.data.Dataset):
|
class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
version: str | None = CODEBASE_VERSION,
|
|
||||||
root: Path | None = DATA_DIR,
|
root: Path | None = DATA_DIR,
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
|
@ -52,7 +53,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.version = version
|
|
||||||
self.root = root
|
self.root = root
|
||||||
self.split = split
|
self.split = split
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
|
@ -60,16 +60,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
# load data from hub or locally when root is provided
|
# load data from hub or locally when root is provided
|
||||||
# TODO(rcadene, aliberts): implement faster transfer
|
# TODO(rcadene, aliberts): implement faster transfer
|
||||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
|
||||||
if split == "train":
|
if split == "train":
|
||||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
|
||||||
else:
|
else:
|
||||||
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||||
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||||
self.stats = load_stats(repo_id, version, root)
|
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
|
||||||
self.info = load_info(repo_id, version, root)
|
self.info = load_info(repo_id, CODEBASE_VERSION, root)
|
||||||
if self.video:
|
if self.video:
|
||||||
self.videos_dir = load_videos(repo_id, version, root)
|
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
|
||||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -164,7 +164,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
return (
|
return (
|
||||||
f"{self.__class__.__name__}(\n"
|
f"{self.__class__.__name__}(\n"
|
||||||
f" Repository ID: '{self.repo_id}',\n"
|
f" Repository ID: '{self.repo_id}',\n"
|
||||||
f" Version: '{self.version}',\n"
|
|
||||||
f" Split: '{self.split}',\n"
|
f" Split: '{self.split}',\n"
|
||||||
f" Number of Samples: {self.num_samples},\n"
|
f" Number of Samples: {self.num_samples},\n"
|
||||||
f" Number of Episodes: {self.num_episodes},\n"
|
f" Number of Episodes: {self.num_episodes},\n"
|
||||||
|
@ -173,6 +172,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
f" Camera Keys: {self.camera_keys},\n"
|
f" Camera Keys: {self.camera_keys},\n"
|
||||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||||
f" Transformations: {self.image_transforms},\n"
|
f" Transformations: {self.image_transforms},\n"
|
||||||
|
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
|
||||||
f")"
|
f")"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -180,7 +180,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
def from_preloaded(
|
def from_preloaded(
|
||||||
cls,
|
cls,
|
||||||
repo_id: str = "from_preloaded",
|
repo_id: str = "from_preloaded",
|
||||||
version: str | None = CODEBASE_VERSION,
|
|
||||||
root: Path | None = None,
|
root: Path | None = None,
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
transform: callable = None,
|
transform: callable = None,
|
||||||
|
@ -204,7 +203,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
# create an empty object of type LeRobotDataset
|
# create an empty object of type LeRobotDataset
|
||||||
obj = cls.__new__(cls)
|
obj = cls.__new__(cls)
|
||||||
obj.repo_id = repo_id
|
obj.repo_id = repo_id
|
||||||
obj.version = version
|
|
||||||
obj.root = root
|
obj.root = root
|
||||||
obj.split = split
|
obj.split = split
|
||||||
obj.image_transforms = transform
|
obj.image_transforms = transform
|
||||||
|
@ -228,7 +226,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
repo_ids: list[str],
|
repo_ids: list[str],
|
||||||
version: str | None = CODEBASE_VERSION,
|
|
||||||
root: Path | None = DATA_DIR,
|
root: Path | None = DATA_DIR,
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
|
@ -242,7 +239,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||||
self._datasets = [
|
self._datasets = [
|
||||||
LeRobotDataset(
|
LeRobotDataset(
|
||||||
repo_id,
|
repo_id,
|
||||||
version=version,
|
|
||||||
root=root,
|
root=root,
|
||||||
split=split,
|
split=split,
|
||||||
delta_timestamps=delta_timestamps,
|
delta_timestamps=delta_timestamps,
|
||||||
|
@ -279,7 +275,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||||
)
|
)
|
||||||
self.disabled_data_keys.update(extra_keys)
|
self.disabled_data_keys.update(extra_keys)
|
||||||
|
|
||||||
self.version = version
|
|
||||||
self.root = root
|
self.root = root
|
||||||
self.split = split
|
self.split = split
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
|
@ -395,7 +390,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||||
return (
|
return (
|
||||||
f"{self.__class__.__name__}(\n"
|
f"{self.__class__.__name__}(\n"
|
||||||
f" Repository IDs: '{self.repo_ids}',\n"
|
f" Repository IDs: '{self.repo_ids}',\n"
|
||||||
f" Version: '{self.version}',\n"
|
|
||||||
f" Split: '{self.split}',\n"
|
f" Split: '{self.split}',\n"
|
||||||
f" Number of Samples: {self.num_samples},\n"
|
f" Number of Samples: {self.num_samples},\n"
|
||||||
f" Number of Episodes: {self.num_episodes},\n"
|
f" Number of Episodes: {self.num_episodes},\n"
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
## Using / Updating `CODEBASE_VERSION` (for maintainers)
|
||||||
|
|
||||||
|
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
|
||||||
|
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
|
||||||
|
lerobot/common/datasets/lerobot_dataset.py) variable.
|
||||||
|
|
||||||
|
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
|
||||||
|
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
||||||
|
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
||||||
|
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
||||||
|
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
||||||
|
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
||||||
|
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
|
||||||
|
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
||||||
|
|
||||||
|
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
|
||||||
|
`info.json` metadata.
|
||||||
|
|
||||||
|
### Uploading a new dataset
|
||||||
|
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
|
||||||
|
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
|
||||||
|
dataset with the current `CODEBASE_VERSION`.
|
||||||
|
|
||||||
|
### Updating an existing dataset
|
||||||
|
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
|
||||||
|
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
||||||
|
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
||||||
|
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
||||||
|
codebase won't be affected by your change and backward compatibility is maintained.
|
||||||
|
|
||||||
|
However, you will need to update the version of ALL the other datasets so that they have the new
|
||||||
|
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
|
||||||
|
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
|
||||||
|
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
|
from lerobot import available_datasets
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
|
|
||||||
|
api = HfApi()
|
||||||
|
|
||||||
|
for repo_id in available_datasets:
|
||||||
|
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||||
|
branches = [b.name for b in dataset_info.branches]
|
||||||
|
if CODEBASE_VERSION in branches:
|
||||||
|
# First check if the newer version already exists.
|
||||||
|
print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
|
||||||
|
print("Exiting early")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Now create a branch named after the new version by branching out from "main"
|
||||||
|
# which is expected to be the preceding version
|
||||||
|
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
|
||||||
|
print(f"{repo_id} successfully updated")
|
||||||
|
```
|
|
@ -32,46 +32,41 @@ from pathlib import Path
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
AVAILABLE_RAW_REPO_IDS = [
|
AVAILABLE_RAW_REPO_IDS = [
|
||||||
"cadene/pusht_image_raw",
|
"lerobot-raw/aloha_mobile_cabinet_raw",
|
||||||
"cadene/xarm_lift_medium_image_raw",
|
"lerobot-raw/aloha_mobile_chair_raw",
|
||||||
"cadene/xarm_lift_medium_replay_image_raw",
|
"lerobot-raw/aloha_mobile_elevator_raw",
|
||||||
"cadene/xarm_push_medium_image_raw",
|
"lerobot-raw/aloha_mobile_shrimp_raw",
|
||||||
"cadene/xarm_push_medium_replay_image_raw",
|
"lerobot-raw/aloha_mobile_wash_pan_raw",
|
||||||
"cadene/aloha_sim_insertion_human_image_raw",
|
"lerobot-raw/aloha_mobile_wipe_wine_raw",
|
||||||
"cadene/aloha_sim_insertion_scripted_image_raw",
|
"lerobot-raw/aloha_sim_insertion_human_raw",
|
||||||
"cadene/aloha_sim_transfer_cube_human_image_raw",
|
"lerobot-raw/aloha_sim_insertion_scripted_raw",
|
||||||
"cadene/aloha_sim_transfer_cube_scripted_image_raw",
|
"lerobot-raw/aloha_sim_transfer_cube_human_raw",
|
||||||
"cadene/pusht_raw",
|
"lerobot-raw/aloha_sim_transfer_cube_scripted_raw",
|
||||||
"cadene/xarm_lift_medium_raw",
|
"lerobot-raw/aloha_static_battery_raw",
|
||||||
"cadene/xarm_lift_medium_replay_raw",
|
"lerobot-raw/aloha_static_candy_raw",
|
||||||
"cadene/xarm_push_medium_raw",
|
"lerobot-raw/aloha_static_coffee_new_raw",
|
||||||
"cadene/xarm_push_medium_replay_raw",
|
"lerobot-raw/aloha_static_coffee_raw",
|
||||||
"cadene/aloha_sim_insertion_human_raw",
|
"lerobot-raw/aloha_static_cups_open_raw",
|
||||||
"cadene/aloha_sim_insertion_scripted_raw",
|
"lerobot-raw/aloha_static_fork_pick_up_raw",
|
||||||
"cadene/aloha_sim_transfer_cube_human_raw",
|
"lerobot-raw/aloha_static_pingpong_test_raw",
|
||||||
"cadene/aloha_sim_transfer_cube_scripted_raw",
|
"lerobot-raw/aloha_static_pro_pencil_raw",
|
||||||
"cadene/aloha_mobile_cabinet_raw",
|
"lerobot-raw/aloha_static_screw_driver_raw",
|
||||||
"cadene/aloha_mobile_chair_raw",
|
"lerobot-raw/aloha_static_tape_raw",
|
||||||
"cadene/aloha_mobile_elevator_raw",
|
"lerobot-raw/aloha_static_thread_velcro_raw",
|
||||||
"cadene/aloha_mobile_shrimp_raw",
|
"lerobot-raw/aloha_static_towel_raw",
|
||||||
"cadene/aloha_mobile_wash_pan_raw",
|
"lerobot-raw/aloha_static_vinh_cup_left_raw",
|
||||||
"cadene/aloha_mobile_wipe_wine_raw",
|
"lerobot-raw/aloha_static_vinh_cup_raw",
|
||||||
"cadene/aloha_static_battery_raw",
|
"lerobot-raw/aloha_static_ziploc_slide_raw",
|
||||||
"cadene/aloha_static_candy_raw",
|
"lerobot-raw/pusht_raw",
|
||||||
"cadene/aloha_static_coffee_raw",
|
"lerobot-raw/umi_cup_in_the_wild_raw",
|
||||||
"cadene/aloha_static_coffee_new_raw",
|
"lerobot-raw/unitreeh1_fold_clothes_raw",
|
||||||
"cadene/aloha_static_cups_open_raw",
|
"lerobot-raw/unitreeh1_rearrange_objects_raw",
|
||||||
"cadene/aloha_static_fork_pick_up_raw",
|
"lerobot-raw/unitreeh1_two_robot_greeting_raw",
|
||||||
"cadene/aloha_static_pingpong_test_raw",
|
"lerobot-raw/unitreeh1_warehouse_raw",
|
||||||
"cadene/aloha_static_pro_pencil_raw",
|
"lerobot-raw/xarm_lift_medium_raw",
|
||||||
"cadene/aloha_static_screw_driver_raw",
|
"lerobot-raw/xarm_lift_medium_replay_raw",
|
||||||
"cadene/aloha_static_tape_raw",
|
"lerobot-raw/xarm_push_medium_raw",
|
||||||
"cadene/aloha_static_thread_velcro_raw",
|
"lerobot-raw/xarm_push_medium_replay_raw",
|
||||||
"cadene/aloha_static_towel_raw",
|
|
||||||
"cadene/aloha_static_vinh_cup_raw",
|
|
||||||
"cadene/aloha_static_vinh_cup_left_raw",
|
|
||||||
"cadene/aloha_static_ziploc_slide_raw",
|
|
||||||
"cadene/umi_cup_in_the_wild_raw",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -89,7 +84,6 @@ def download_raw(raw_dir: Path, repo_id: str):
|
||||||
stacklevel=1,
|
stacklevel=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_dir = Path(raw_dir)
|
|
||||||
# Send warning if raw_dir isn't well formated
|
# Send warning if raw_dir isn't well formated
|
||||||
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
|
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -99,7 +93,7 @@ def download_raw(raw_dir: Path, repo_id: str):
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||||
snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
|
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
||||||
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ import tqdm
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
calculate_episode_data_index,
|
||||||
|
@ -210,6 +211,7 @@ def from_raw_to_lerobot_format(
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
info = {
|
info = {
|
||||||
|
"codebase_version": CODEBASE_VERSION,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"video": video,
|
"video": video,
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import torch
|
||||||
from datasets import Dataset, Features, Image, Value
|
from datasets import Dataset, Features, Image, Value
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||||
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
|
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame
|
from lerobot.common.datasets.video_utils import VideoFrame
|
||||||
|
@ -95,6 +96,7 @@ def from_raw_to_lerobot_format(
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
info = {
|
info = {
|
||||||
|
"codebase_version": CODEBASE_VERSION,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"video": video,
|
"video": video,
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
calculate_episode_data_index,
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
|
@ -214,6 +215,7 @@ def from_raw_to_lerobot_format(
|
||||||
hf_dataset = to_hf_dataset(data_df, video)
|
hf_dataset = to_hf_dataset(data_df, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
info = {
|
info = {
|
||||||
|
"codebase_version": CODEBASE_VERSION,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"video": video,
|
"video": video,
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import zarr
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
calculate_episode_data_index,
|
||||||
|
@ -258,6 +259,7 @@ def from_raw_to_lerobot_format(
|
||||||
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
|
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
info = {
|
info = {
|
||||||
|
"codebase_version": CODEBASE_VERSION,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"video": video if not keypoints_instead_of_image else 0,
|
"video": video if not keypoints_instead_of_image else 0,
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import zarr
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
|
@ -199,6 +200,7 @@ def from_raw_to_lerobot_format(
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
info = {
|
info = {
|
||||||
|
"codebase_version": CODEBASE_VERSION,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"video": video,
|
"video": video,
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import tqdm
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
calculate_episode_data_index,
|
||||||
|
@ -177,6 +178,7 @@ def from_raw_to_lerobot_format(
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
info = {
|
info = {
|
||||||
|
"codebase_version": CODEBASE_VERSION,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"video": video,
|
"video": video,
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,13 +15,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import warnings
|
||||||
|
from functools import cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset, load_from_disk
|
from datasets import load_dataset, load_from_disk
|
||||||
from huggingface_hub import hf_hub_download, snapshot_download
|
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
@ -80,7 +82,28 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||||
return items_dict
|
return items_dict
|
||||||
|
|
||||||
|
|
||||||
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
@cache
|
||||||
|
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
|
||||||
|
api = HfApi()
|
||||||
|
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||||
|
branches = [b.name for b in dataset_info.branches]
|
||||||
|
if version not in branches:
|
||||||
|
warnings.warn(
|
||||||
|
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
|
||||||
|
codebase. The following versions are available: {branches}.
|
||||||
|
The requested version ('{version}') is not found. You should be fine since
|
||||||
|
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||||
|
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||||
|
stacklevel=1,
|
||||||
|
)
|
||||||
|
if "main" not in branches:
|
||||||
|
raise ValueError(f"Version 'main' not found on {repo_id}")
|
||||||
|
return "main"
|
||||||
|
else:
|
||||||
|
return version
|
||||||
|
|
||||||
|
|
||||||
|
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
|
||||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||||
if root is not None:
|
if root is not None:
|
||||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||||
|
@ -101,7 +124,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||||
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
|
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||||
|
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
|
||||||
|
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
|
@ -119,8 +144,9 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
|
||||||
if root is not None:
|
if root is not None:
|
||||||
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
||||||
else:
|
else:
|
||||||
|
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||||
path = hf_hub_download(
|
path = hf_hub_download(
|
||||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
|
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
|
||||||
)
|
)
|
||||||
|
|
||||||
return load_file(path)
|
return load_file(path)
|
||||||
|
@ -137,7 +163,10 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||||
if root is not None:
|
if root is not None:
|
||||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
||||||
else:
|
else:
|
||||||
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
|
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||||
|
path = hf_hub_download(
|
||||||
|
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
|
||||||
|
)
|
||||||
|
|
||||||
stats = load_file(path)
|
stats = load_file(path)
|
||||||
return unflatten_dict(stats)
|
return unflatten_dict(stats)
|
||||||
|
@ -154,7 +183,8 @@ def load_info(repo_id, version, root) -> dict:
|
||||||
if root is not None:
|
if root is not None:
|
||||||
path = Path(root) / repo_id / "meta_data" / "info.json"
|
path = Path(root) / repo_id / "meta_data" / "info.json"
|
||||||
else:
|
else:
|
||||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
|
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||||
|
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
|
||||||
|
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
info = json.load(f)
|
info = json.load(f)
|
||||||
|
@ -166,7 +196,8 @@ def load_videos(repo_id, version, root) -> Path:
|
||||||
path = Path(root) / repo_id / "videos"
|
path = Path(root) / repo_id / "videos"
|
||||||
else:
|
else:
|
||||||
# TODO(rcadene): we download the whole repo here. see if we can avoid this
|
# TODO(rcadene): we download the whole repo here. see if we can avoid this
|
||||||
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version)
|
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||||
|
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
|
||||||
path = Path(repo_dir) / "videos"
|
path = Path(repo_dir) / "videos"
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
|
@ -76,12 +76,10 @@ class ACTConfig:
|
||||||
documentation in the policy class).
|
documentation in the policy class).
|
||||||
latent_dim: The VAE's latent dimension.
|
latent_dim: The VAE's latent dimension.
|
||||||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
||||||
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
|
temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
|
||||||
actions for a given time step over multiple policy invocations. Updates are calculated as:
|
ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
|
||||||
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
|
1 when using this feature, as inference needs to happen at every step to form an ensemble. For
|
||||||
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
|
more information on how ensembling works, please see `ACTTemporalEnsembler`.
|
||||||
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
|
|
||||||
require `n_action_steps == 1` (since we need to query the policy every step anyway).
|
|
||||||
dropout: Dropout to use in the transformer layers (see code for details).
|
dropout: Dropout to use in the transformer layers (see code for details).
|
||||||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
||||||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||||||
|
@ -139,7 +137,8 @@ class ACTConfig:
|
||||||
n_vae_encoder_layers: int = 4
|
n_vae_encoder_layers: int = 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: float | None = None
|
# Note: the value used in ACT when temporal ensembling is enabled is 0.01.
|
||||||
|
temporal_ensemble_coeff: float | None = None
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: float = 0.1
|
dropout: float = 0.1
|
||||||
|
@ -151,7 +150,7 @@ class ACTConfig:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
)
|
)
|
||||||
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
|
if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
||||||
"because the policy needs to be queried every step to compute the ensembled action."
|
"because the policy needs to be queried every step to compute the ensembled action."
|
||||||
|
|
|
@ -77,12 +77,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
|
|
||||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
|
|
||||||
|
if config.temporal_ensemble_coeff is not None:
|
||||||
|
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""This should be called whenever the environment is reset."""
|
"""This should be called whenever the environment is reset."""
|
||||||
if self.config.temporal_ensemble_momentum is not None:
|
if self.config.temporal_ensemble_coeff is not None:
|
||||||
self._ensembled_actions = None
|
self.temporal_ensembler.reset()
|
||||||
else:
|
else:
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
|
@ -100,24 +103,12 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
if len(self.expected_image_keys) > 0:
|
if len(self.expected_image_keys) > 0:
|
||||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||||
|
|
||||||
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||||
# the first action.
|
# we are ensembling over.
|
||||||
if self.config.temporal_ensemble_momentum is not None:
|
if self.config.temporal_ensemble_coeff is not None:
|
||||||
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
if self._ensembled_actions is None:
|
action = self.temporal_ensembler.update(actions)
|
||||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
|
||||||
# time step of the episode.
|
|
||||||
self._ensembled_actions = actions.clone()
|
|
||||||
else:
|
|
||||||
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
|
||||||
# the EMA update for those entries.
|
|
||||||
alpha = self.config.temporal_ensemble_momentum
|
|
||||||
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
|
|
||||||
# The last action, which has no prior moving average, needs to get concatenated onto the end.
|
|
||||||
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
|
|
||||||
# "Consume" the first action.
|
|
||||||
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
|
|
||||||
return action
|
return action
|
||||||
|
|
||||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||||
|
@ -162,6 +153,97 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ACTTemporalEnsembler:
|
||||||
|
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
|
||||||
|
"""Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
|
||||||
|
|
||||||
|
The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
|
||||||
|
They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
|
||||||
|
coefficient works:
|
||||||
|
- Setting it to 0 uniformly weighs all actions.
|
||||||
|
- Setting it positive gives more weight to older actions.
|
||||||
|
- Setting it negative gives more weight to newer actions.
|
||||||
|
NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
|
||||||
|
results in older actions being weighed more highly than newer actions (the experiments documented in
|
||||||
|
https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
|
||||||
|
detrimental: doing so aggressively may diminish the benefits of action chunking).
|
||||||
|
|
||||||
|
Here we use an online method for computing the average rather than caching a history of actions in
|
||||||
|
order to compute the average offline. For a simple 1D sequence it looks something like:
|
||||||
|
|
||||||
|
```
|
||||||
|
import torch
|
||||||
|
|
||||||
|
seq = torch.linspace(8, 8.5, 100)
|
||||||
|
print(seq)
|
||||||
|
|
||||||
|
m = 0.01
|
||||||
|
exp_weights = torch.exp(-m * torch.arange(len(seq)))
|
||||||
|
print(exp_weights)
|
||||||
|
|
||||||
|
# Calculate offline
|
||||||
|
avg = (exp_weights * seq).sum() / exp_weights.sum()
|
||||||
|
print("offline", avg)
|
||||||
|
|
||||||
|
# Calculate online
|
||||||
|
for i, item in enumerate(seq):
|
||||||
|
if i == 0:
|
||||||
|
avg = item
|
||||||
|
continue
|
||||||
|
avg *= exp_weights[:i].sum()
|
||||||
|
avg += item * exp_weights[i]
|
||||||
|
avg /= exp_weights[:i+1].sum()
|
||||||
|
print("online", avg)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||||
|
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Resets the online computation variables."""
|
||||||
|
self.ensembled_actions = None
|
||||||
|
# (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
|
||||||
|
self.ensembled_actions_count = None
|
||||||
|
|
||||||
|
def update(self, actions: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
|
||||||
|
time steps, and pop/return the next batch of actions in the sequence.
|
||||||
|
"""
|
||||||
|
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
|
||||||
|
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
|
||||||
|
if self.ensembled_actions is None:
|
||||||
|
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||||
|
# time step of the episode.
|
||||||
|
self.ensembled_actions = actions.clone()
|
||||||
|
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
||||||
|
# operations later.
|
||||||
|
self.ensembled_actions_count = torch.ones(
|
||||||
|
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||||
|
# the online update for those entries.
|
||||||
|
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
|
||||||
|
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||||
|
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
|
||||||
|
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
|
||||||
|
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
||||||
|
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
||||||
|
self.ensembled_actions_count = torch.cat(
|
||||||
|
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
||||||
|
)
|
||||||
|
# "Consume" the first action.
|
||||||
|
action, self.ensembled_actions, self.ensembled_actions_count = (
|
||||||
|
self.ensembled_actions[:, 0],
|
||||||
|
self.ensembled_actions[:, 1:],
|
||||||
|
self.ensembled_actions_count[1:],
|
||||||
|
)
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
class ACT(nn.Module):
|
class ACT(nn.Module):
|
||||||
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
||||||
|
|
||||||
|
|
|
@ -75,7 +75,7 @@ policy:
|
||||||
n_vae_encoder_layers: 4
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: null
|
temporal_ensemble_coeff: null
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
|
|
|
@ -107,7 +107,7 @@ policy:
|
||||||
n_vae_encoder_layers: 4
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: null
|
temporal_ensemble_coeff: null
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
|
|
|
@ -103,7 +103,7 @@ policy:
|
||||||
n_vae_encoder_layers: 4
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: null
|
temporal_ensemble_coeff: null
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
|
|
|
@ -494,6 +494,7 @@ def record_dataset(
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
info = {
|
info = {
|
||||||
|
"codebase_version": CODEBASE_VERSION,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"video": video,
|
"video": video,
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,60 +40,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--raw-format umi_zarr \
|
--raw-format umi_zarr \
|
||||||
--repo-id lerobot/umi_cup_in_the_wild
|
--repo-id lerobot/umi_cup_in_the_wild
|
||||||
```
|
```
|
||||||
|
|
||||||
**WARNING: Updating an existing dataset**
|
|
||||||
|
|
||||||
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
|
|
||||||
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
|
||||||
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
|
||||||
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
|
||||||
codebase won't be affected by your change and backward compatibility is maintained.
|
|
||||||
|
|
||||||
For instance, Pusht has many versions to maintain backward compatibility between LeRobot codebase versions:
|
|
||||||
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
|
||||||
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
|
||||||
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
|
||||||
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
|
||||||
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
|
||||||
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
|
|
||||||
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
|
||||||
|
|
||||||
However, you will need to update the version of ALL the other datasets so that they have the new
|
|
||||||
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
|
|
||||||
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
|
|
||||||
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
|
||||||
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
|
|
||||||
from huggingface_hub import create_branch, hf_hub_download
|
|
||||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
|
||||||
|
|
||||||
from lerobot import available_datasets
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
|
||||||
|
|
||||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # makes it easier to see the print-out below
|
|
||||||
|
|
||||||
NEW_CODEBASE_VERSION = "v1.5" # REPLACE THIS WITH YOUR DESIRED VERSION
|
|
||||||
|
|
||||||
for repo_id in available_datasets:
|
|
||||||
# First check if the newer version already exists.
|
|
||||||
try:
|
|
||||||
hf_hub_download(
|
|
||||||
repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
|
|
||||||
)
|
|
||||||
print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
|
|
||||||
print("Exiting early")
|
|
||||||
break
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
# Now create a branch.
|
|
||||||
create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
|
|
||||||
print(f"{repo_id} successfully updated")
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
|
|
||||||
above, nor to be compatible with previous codebase versions.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -104,7 +50,7 @@ from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import HfApi, create_branch
|
from huggingface_hub import HfApi
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import compute_stats
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
|
@ -270,7 +216,8 @@ def push_dataset_to_hub(
|
||||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||||
if video:
|
if video:
|
||||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
api = HfApi()
|
||||||
|
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||||
|
|
||||||
if tests_data_dir:
|
if tests_data_dir:
|
||||||
# get the first episode
|
# get the first episode
|
||||||
|
|
|
@ -272,7 +272,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
cfg.resume = True
|
cfg.resume = True
|
||||||
elif Logger.get_last_checkpoint_dir(out_dir).exists():
|
elif Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists."
|
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
|
||||||
|
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
|
||||||
)
|
)
|
||||||
|
|
||||||
# log metrics to terminal and wandb
|
# log metrics to terminal and wandb
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import inspect
|
import inspect
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import einops
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
|
@ -26,6 +27,7 @@ from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.utils import preprocess_observation
|
from lerobot.common.envs.utils import preprocess_observation
|
||||||
|
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
|
||||||
from lerobot.common.policies.factory import (
|
from lerobot.common.policies.factory import (
|
||||||
_policy_cfg_from_hydra_cfg,
|
_policy_cfg_from_hydra_cfg,
|
||||||
get_policy_and_config_classes,
|
get_policy_and_config_classes,
|
||||||
|
@ -33,7 +35,7 @@ from lerobot.common.policies.factory import (
|
||||||
)
|
)
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||||
|
@ -390,3 +392,62 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
|
||||||
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
|
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
|
||||||
for key in saved_actions:
|
for key in saved_actions:
|
||||||
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
|
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_act_temporal_ensembler():
|
||||||
|
"""Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
|
||||||
|
temporal_ensemble_coeff = 0.01
|
||||||
|
chunk_size = 100
|
||||||
|
episode_length = 101
|
||||||
|
ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
|
||||||
|
# An batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
|
||||||
|
# "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
|
||||||
|
with seeded_context(0):
|
||||||
|
# Dimension is (batch, episode_length, chunk_size, action_dim(=1))
|
||||||
|
# Stepping through the episode_length dim is like running inference at each rollout step and getting
|
||||||
|
# a different action chunk.
|
||||||
|
batch_seq = torch.stack(
|
||||||
|
[
|
||||||
|
torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
|
||||||
|
torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
|
||||||
|
torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
).unsqueeze(-1) # unsqueeze for action dim
|
||||||
|
batch_size = batch_seq.shape[0]
|
||||||
|
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
|
||||||
|
# dimension of `batch_seq`.
|
||||||
|
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
|
||||||
|
|
||||||
|
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
|
||||||
|
for i in range(episode_length):
|
||||||
|
# Mock a batch of actions.
|
||||||
|
actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
|
||||||
|
online_avg = ensembler.update(actions)
|
||||||
|
# Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
|
||||||
|
# Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
|
||||||
|
# What we want to do is take diagonal slices across it starting from the left.
|
||||||
|
# eg: chunk_size=4, episode_length=6
|
||||||
|
# ┌───────┐
|
||||||
|
# │0 1 2 3│
|
||||||
|
# │1 2 3 4│
|
||||||
|
# │2 3 4 5│
|
||||||
|
# │3 4 5 6│
|
||||||
|
# │4 5 6 7│
|
||||||
|
# │5 6 7 8│
|
||||||
|
# └───────┘
|
||||||
|
chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
|
||||||
|
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
|
||||||
|
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
|
||||||
|
offline_avg = (
|
||||||
|
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
|
||||||
|
)
|
||||||
|
# Sanity check. The average should be between the extrema.
|
||||||
|
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
|
||||||
|
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||||
|
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
||||||
|
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_act_temporal_ensembler()
|
||||||
|
|
Loading…
Reference in New Issue