diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
new file mode 100644
index 00000000..7397327d
--- /dev/null
+++ b/lerobot/common/datasets/aloha.py
@@ -0,0 +1,185 @@
+import logging
+from pathlib import Path
+from typing import Callable
+
+import einops
+import gdown
+import h5py
+import torch
+import torchrl
+import tqdm
+from tensordict import TensorDict
+from torchrl.data.replay_buffers.samplers import SliceSampler
+from torchrl.data.replay_buffers.storages import TensorStorage
+from torchrl.data.replay_buffers.writers import Writer
+
+from lerobot.common.datasets.abstract import AbstractExperienceReplay
+
+DATASET_IDS = [
+    "aloha_sim_insertion_human",
+    "aloha_sim_insertion_scripted",
+    "aloha_sim_transfer_cube_human",
+    "aloha_sim_transfer_cube_scripted",
+]
+
+FOLDER_URLS = {
+    "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
+    "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
+    "aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
+    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
+}
+
+EP48_URLS = {
+    "aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
+    "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
+    "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
+    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
+}
+
+EP49_URLS = {
+    "aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
+    "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
+    "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
+    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
+}
+
+NUM_EPISODES = {
+    "aloha_sim_insertion_human": 50,
+    "aloha_sim_insertion_scripted": 50,
+    "aloha_sim_transfer_cube_human": 50,
+    "aloha_sim_transfer_cube_scripted": 50,
+}
+
+EPISODE_LEN = {
+    "aloha_sim_insertion_human": 500,
+    "aloha_sim_insertion_scripted": 400,
+    "aloha_sim_transfer_cube_human": 400,
+    "aloha_sim_transfer_cube_scripted": 400,
+}
+
+CAMERAS = {
+    "aloha_sim_insertion_human": ["top"],
+    "aloha_sim_insertion_scripted": ["top"],
+    "aloha_sim_transfer_cube_human": ["top"],
+    "aloha_sim_transfer_cube_scripted": ["top"],
+}
+
+
+def download(data_dir, dataset_id):
+    assert dataset_id in DATASET_IDS
+    assert dataset_id in FOLDER_URLS
+    assert dataset_id in EP48_URLS
+    assert dataset_id in EP49_URLS
+
+    data_dir.mkdir(parents=True, exist_ok=True)
+
+    gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)
+
+    # because of the 50-file limit per folder download, episodes 48 and 49 are missing and are fetched separately
+    gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
+    gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)
+
+
+class AlohaExperienceReplay(AbstractExperienceReplay):
+    def __init__(
+        self,
+        dataset_id: str,
+        batch_size: int = None,
+        *,
+        shuffle: bool = True,
+        root: Path = None,
+        pin_memory: bool = False,
+        prefetch: int = None,
+        sampler: SliceSampler = None,
+        collate_fn: Callable = None,
+        writer: Writer = None,
+        transform: "torchrl.envs.Transform" = None,
+    ):
+        assert dataset_id in DATASET_IDS
+
+        super().__init__(
+            dataset_id,
+            batch_size,
+            shuffle=shuffle,
+            root=root,
+            pin_memory=pin_memory,
+            prefetch=prefetch,
+            sampler=sampler,
+            collate_fn=collate_fn,
+            writer=writer,
+            transform=transform,
+        )
+
+    @property
+    def stats_patterns(self) -> dict:
+        d = {
+            ("observation", "state"): "b c -> 1 c",
+            ("action"): "b c -> 1 c",
+        }
+        for cam in CAMERAS[self.dataset_id]:
+            d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
+        return d
+
+    @property
+    def image_keys(self) -> list:
+        return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
+
+    # def _is_downloaded(self) -> bool:
+    #     return False
+
+    def _download_and_preproc(self):
+        raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
+        if not raw_dir.is_dir():
+            download(raw_dir, self.dataset_id)
+
+        total_num_frames = 0
+        logging.info("Compute total number of frames to initialize offline buffer")
+        for ep_id in range(NUM_EPISODES[self.dataset_id]):
+            ep_path = raw_dir / f"episode_{ep_id}.hdf5"
+            with h5py.File(ep_path, "r") as ep:
+                total_num_frames += ep["/action"].shape[0] - 1
+        logging.info(f"{total_num_frames=}")
+
+        logging.info("Initialize and feed offline buffer")
+        idxtd = 0
+        for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
+            ep_path = raw_dir / f"episode_{ep_id}.hdf5"
+            with h5py.File(ep_path, "r") as ep:
+                ep_num_frames = ep["/action"].shape[0]
+
+                # last step of demonstration is considered done
+                done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
+                done[-1] = True
+
+                state = torch.from_numpy(ep["/observations/qpos"][:])
+                action = torch.from_numpy(ep["/action"][:])
+
+                ep_td = TensorDict(
+                    {
+                        ("observation", "state"): state[:-1],
+                        "action": action[:-1],
+                        "episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
+                        "frame_id": torch.arange(0, ep_num_frames - 1, 1),
+                        ("next", "observation", "state"): state[1:],
+                        # TODO: compute reward and success
+                        # ("next", "reward"): reward[1:],
+                        ("next", "done"): done[1:],
+                        # ("next", "success"): success[1:],
+                    },
+                    batch_size=ep_num_frames - 1,
+                )
+
+                for cam in CAMERAS[self.dataset_id]:
+                    image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
+                    image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
+                    ep_td["observation", "image", cam] = image[:-1]
+                    ep_td["next", "observation", "image", cam] = image[1:]
+
+                if ep_id == 0:
+                    # hack to initialize tensordict data structure to store episodes
+                    td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir)
+
+                td_data[idxtd : idxtd + len(ep_td)] = ep_td
+                idxtd = idxtd + len(ep_td)
+
+        return TensorStorage(td_data.lock_())
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index e054682e..5bb4b14f 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -66,6 +66,12 @@ def make_offline_buffer(
         clsfunc = PushtExperienceReplay
         dataset_id = "pusht"
+
+    elif cfg.env.name == "aloha":
+        from lerobot.common.datasets.aloha import AlohaExperienceReplay
+
+        clsfunc = AlohaExperienceReplay
+        dataset_id = f"aloha_{cfg.env.task}"
     else:
         raise ValueError(cfg.env.name)
diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml
new file mode 100644
index 00000000..5b5ecbb7
--- /dev/null
+++ b/lerobot/configs/env/aloha.yaml
@@ -0,0 +1,25 @@
+# @package _global_
+
+eval_episodes: 50
+eval_freq: 7500
+save_freq: 75000
+log_freq: 250
+# TODO: same as simxarm, need to adjust
+offline_steps: 25000
+online_steps: 25000
+
+fps: 50
+
+env:
+  name: aloha
+  task: sim_insertion_human
+  from_pixels: True
+  pixels_only: False
+  image_size: 96
+  action_repeat: 1
+  episode_length: 300
+  fps: ${fps}
+
+policy:
+  state_dim: 2
+  action_dim: 2
diff --git a/poetry.lock b/poetry.lock
index 4b96b902..9a35071b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -822,6 +822,43 @@ files = [
     {file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"},
 ]
 
+[[package]]
+name = "h5py"
+version = "3.10.0"
+description = "Read and write HDF5 files from Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "h5py-3.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b963fb772964fc1d1563c57e4e2e874022ce11f75ddc6df1a626f42bd49ab99f"},
+    {file = "h5py-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:012ab448590e3c4f5a8dd0f3533255bc57f80629bf7c5054cf4c87b30085063c"},
+    {file = "h5py-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:781a24263c1270a62cd67be59f293e62b76acfcc207afa6384961762bb88ea03"},
+    {file = "h5py-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42e6c30698b520f0295d70157c4e202a9e402406f50dc08f5a7bc416b24e52d"},
+    {file = "h5py-3.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:93dd840bd675787fc0b016f7a05fc6efe37312a08849d9dd4053fd0377b1357f"},
+    {file = "h5py-3.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2381e98af081b6df7f6db300cd88f88e740649d77736e4b53db522d8874bf2dc"},
+    {file = "h5py-3.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:667fe23ab33d5a8a6b77970b229e14ae3bb84e4ea3382cc08567a02e1499eedd"},
+    {file = "h5py-3.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90286b79abd085e4e65e07c1bd7ee65a0f15818ea107f44b175d2dfe1a4674b7"},
+    {file = "h5py-3.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c013d2e79c00f28ffd0cc24e68665ea03ae9069e167087b2adb5727d2736a52"},
+    {file = "h5py-3.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:92273ce69ae4983dadb898fd4d3bea5eb90820df953b401282ee69ad648df684"},
+    {file = "h5py-3.10.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c97d03f87f215e7759a354460fb4b0d0f27001450b18b23e556e7856a0b21c3"},
+    {file = "h5py-3.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86df4c2de68257b8539a18646ceccdcf2c1ce6b1768ada16c8dcfb489eafae20"},
+    {file = "h5py-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba9ab36be991119a3ff32d0c7cbe5faf9b8d2375b5278b2aea64effbeba66039"},
+    {file = "h5py-3.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c8e4fda19eb769e9a678592e67eaec3a2f069f7570c82d2da909c077aa94339"},
+    {file = "h5py-3.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:492305a074327e8d2513011fa9fffeb54ecb28a04ca4c4227d7e1e9616d35641"},
+    {file = "h5py-3.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9450464b458cca2c86252b624279115dcaa7260a40d3cb1594bf2b410a2bd1a3"},
+    {file = "h5py-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd6f6d1384a9f491732cee233b99cd4bfd6e838a8815cc86722f9d2ee64032af"},
"sha256:3074ec45d3dc6e178c6f96834cf8108bf4a60ccb5ab044e16909580352010a97"}, + {file = "h5py-3.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:212bb997a91e6a895ce5e2f365ba764debeaef5d2dca5c6fb7098d66607adf99"}, + {file = "h5py-3.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5dfc65ac21fa2f630323c92453cadbe8d4f504726ec42f6a56cf80c2f90d6c52"}, + {file = "h5py-3.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d4682b94fd36ab217352be438abd44c8f357c5449b8995e63886b431d260f3d3"}, + {file = "h5py-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aece0e2e1ed2aab076c41802e50a0c3e5ef8816d60ece39107d68717d4559824"}, + {file = "h5py-3.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43a61b2c2ad65b1fabc28802d133eed34debcc2c8b420cb213d3d4ef4d3e2229"}, + {file = "h5py-3.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:ae2f0201c950059676455daf92700eeb57dcf5caaf71b9e1328e6e6593601770"}, + {file = "h5py-3.10.0.tar.gz", hash = "sha256:d93adc48ceeb33347eb24a634fb787efc7ae4644e6ea4ba733d099605045c049"}, +] + +[package.dependencies] +numpy = ">=1.17.3" + [[package]] name = "huggingface-hub" version = "0.21.3" @@ -3103,4 +3140,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "1de157eb9ba5b7016e43385b8e9ac507896103fe1ed70f1c7af2b0de3fa05dc1" +content-hash = "9c3e86956dd11bc8d7823e5e6c5e74a073051b495f71f96179113d99791f7ca0" diff --git a/pyproject.toml b/pyproject.toml index b0bcfce6..64cbb850 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ opencv-python = "^4.9.0.80" diffusion-policy = {git = "https://github.com/real-stanford/diffusion_policy"} diffusers = "^0.26.3" torchvision = "^0.17.1" +h5py = "^3.10.0" [tool.poetry.group.dev.dependencies]