Merge pull request #9 from Cadene/user/rcadene/2024_03_05_aloha_dataset

Add Aloha replay buffer + 4 sim datasets
Remi 2024-03-06 14:51:55 +01:00 committed by GitHub
commit 66373e9b13
5 changed files with 255 additions and 1 deletion

lerobot/common/datasets/aloha.py

@@ -0,0 +1,185 @@
import logging
from pathlib import Path
from typing import Callable

import einops
import gdown
import h5py
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer

from lerobot.common.datasets.abstract import AbstractExperienceReplay
DATASET_IDS = [
    "aloha_sim_insertion_human",
    "aloha_sim_insertion_scripted",
    "aloha_sim_transfer_cube_human",
    "aloha_sim_transfer_cube_scripted",
]

FOLDER_URLS = {
    "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
    "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
    "aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
}

EP48_URLS = {
    "aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
    "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
    "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
}

EP49_URLS = {
    "aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
    "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
    "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
}

NUM_EPISODES = {
    "aloha_sim_insertion_human": 50,
    "aloha_sim_insertion_scripted": 50,
    "aloha_sim_transfer_cube_human": 50,
    "aloha_sim_transfer_cube_scripted": 50,
}

EPISODE_LEN = {
    "aloha_sim_insertion_human": 500,
    "aloha_sim_insertion_scripted": 400,
    "aloha_sim_transfer_cube_human": 400,
    "aloha_sim_transfer_cube_scripted": 400,
}

CAMERAS = {
    "aloha_sim_insertion_human": ["top"],
    "aloha_sim_insertion_scripted": ["top"],
    "aloha_sim_transfer_cube_human": ["top"],
    "aloha_sim_transfer_cube_scripted": ["top"],
}
def download(data_dir, dataset_id):
    assert dataset_id in DATASET_IDS
    assert dataset_id in FOLDER_URLS
    assert dataset_id in EP48_URLS
    assert dataset_id in EP49_URLS

    data_dir.mkdir(parents=True, exist_ok=True)

    gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)

    # gdown folder downloads are capped at 50 files per directory, so episodes 48
    # and 49 are missing from the folder download and are fetched individually.
    gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
    gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)
class AlohaExperienceReplay(AbstractExperienceReplay):
    def __init__(
        self,
        dataset_id: str,
        batch_size: int = None,
        *,
        shuffle: bool = True,
        root: Path = None,
        pin_memory: bool = False,
        prefetch: int = None,
        sampler: SliceSampler = None,
        collate_fn: Callable = None,
        writer: Writer = None,
        transform: "torchrl.envs.Transform" = None,
    ):
        assert dataset_id in DATASET_IDS
        super().__init__(
            dataset_id,
            batch_size,
            shuffle=shuffle,
            root=root,
            pin_memory=pin_memory,
            prefetch=prefetch,
            sampler=sampler,
            collate_fn=collate_fn,
            writer=writer,
            transform=transform,
        )
    @property
    def stats_patterns(self) -> dict:
        # einops reduction patterns used when aggregating dataset statistics:
        # per channel over the batch for states/actions, and per channel over
        # batch and spatial dims for images.
        d = {
            ("observation", "state"): "b c -> 1 c",
            ("action"): "b c -> 1 c",
        }
        for cam in CAMERAS[self.dataset_id]:
            d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
        return d

    @property
    def image_keys(self) -> list:
        return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]

    # def _is_downloaded(self) -> bool:
    #     return False
    def _download_and_preproc(self):
        raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
        if not raw_dir.is_dir():
            download(raw_dir, self.dataset_id)

        # Each episode of T frames yields T - 1 (current, next) transition pairs.
        total_num_frames = 0
        logging.info("Compute total number of frames to initialize offline buffer")
        for ep_id in range(NUM_EPISODES[self.dataset_id]):
            ep_path = raw_dir / f"episode_{ep_id}.hdf5"
            with h5py.File(ep_path, "r") as ep:
                total_num_frames += ep["/action"].shape[0] - 1
        logging.info(f"{total_num_frames=}")

        logging.info("Initialize and feed offline buffer")
        idxtd = 0
        for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
            ep_path = raw_dir / f"episode_{ep_id}.hdf5"
            with h5py.File(ep_path, "r") as ep:
                ep_num_frames = ep["/action"].shape[0]

                # last step of demonstration is considered done
                done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
                done[-1] = True

                state = torch.from_numpy(ep["/observations/qpos"][:])
                action = torch.from_numpy(ep["/action"][:])

                ep_td = TensorDict(
                    {
                        ("observation", "state"): state[:-1],
                        "action": action[:-1],
                        "episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
                        "frame_id": torch.arange(0, ep_num_frames - 1, 1),
                        ("next", "observation", "state"): state[1:],
                        # TODO: compute reward and success
                        # ("next", "reward"): reward[1:],
                        ("next", "done"): done[1:],
                        # ("next", "success"): success[1:],
                    },
                    batch_size=ep_num_frames - 1,
                )

                for cam in CAMERAS[self.dataset_id]:
                    # HDF5 images are stored channel-last; convert to channel-first.
                    image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
                    image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
                    ep_td["observation", "image", cam] = image[:-1]
                    ep_td["next", "observation", "image", cam] = image[1:]

                if ep_id == 0:
                    # hack to initialize tensordict data structure to store episodes
                    td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir)

                td_data[idxtd : idxtd + len(ep_td)] = ep_td
                idxtd = idxtd + len(ep_td)

        return TensorStorage(td_data.lock_())
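
For orientation, a minimal usage sketch of the new class (not part of this diff; it assumes AbstractExperienceReplay behaves like a torchrl replay buffer whose sample() returns a TensorDict, and the shapes in the comments are illustrative):

# Hypothetical usage sketch, under the assumptions stated above.
from lerobot.common.datasets.aloha import AlohaExperienceReplay

buffer = AlohaExperienceReplay("aloha_sim_insertion_human", batch_size=32)
batch = buffer.sample()  # TensorDict with ("observation", "state"), "action", ("next", ...) keys
print(batch["observation", "state"].shape)  # e.g. torch.Size([32, 14]) for bimanual qpos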


@@ -66,6 +66,12 @@ def make_offline_buffer(
        clsfunc = PushtExperienceReplay
        dataset_id = "pusht"
    elif cfg.env.name == "aloha":
        from lerobot.common.datasets.aloha import AlohaExperienceReplay

        clsfunc = AlohaExperienceReplay
        dataset_id = f"aloha_{cfg.env.task}"
    else:
        raise ValueError(cfg.env.name)
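
To make the dispatch concrete, a hypothetical sketch of the config it consumes (mirroring the aloha.yaml added below; OmegaConf is assumed as the config backend, consistent with the Hydra-style ${fps} interpolation in that file):

# Hypothetical sketch; not part of this diff.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"env": {"name": "aloha", "task": "sim_insertion_human"}})
# With this cfg, make_offline_buffer selects AlohaExperienceReplay and
# builds dataset_id = "aloha_sim_insertion_human".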

lerobot/configs/env/aloha.yaml

@@ -0,0 +1,25 @@
# @package _global_

eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250

# TODO: same as simxarm, need to adjust
offline_steps: 25000
online_steps: 25000

fps: 50
env:
  name: aloha
  task: sim_insertion_human
  from_pixels: True
  pixels_only: False
  image_size: 96
  action_repeat: 1
  episode_length: 300
  fps: ${fps}

policy:
  state_dim: 2
  action_dim: 2
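
A hedged note on the last two values: the bimanual Aloha setup uses 14-dimensional joint-position state and action spaces (two arms, 6 joints plus 1 gripper each), so state_dim: 2 and action_dim: 2 read like placeholders carried over from another config. A plausible follow-up adjustment (an assumption, not part of this diff):

policy:
  state_dim: 14  # assumed: 2 arms x (6 joints + 1 gripper)
  action_dim: 14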

poetry.lock

@@ -822,6 +822,43 @@ files = [
    {file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"},
]

[[package]]
name = "h5py"
version = "3.10.0"
description = "Read and write HDF5 files from Python"
optional = false
python-versions = ">=3.8"
files = [
    {file = "h5py-3.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b963fb772964fc1d1563c57e4e2e874022ce11f75ddc6df1a626f42bd49ab99f"},
    {file = "h5py-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:012ab448590e3c4f5a8dd0f3533255bc57f80629bf7c5054cf4c87b30085063c"},
    {file = "h5py-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:781a24263c1270a62cd67be59f293e62b76acfcc207afa6384961762bb88ea03"},
    {file = "h5py-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42e6c30698b520f0295d70157c4e202a9e402406f50dc08f5a7bc416b24e52d"},
    {file = "h5py-3.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:93dd840bd675787fc0b016f7a05fc6efe37312a08849d9dd4053fd0377b1357f"},
    {file = "h5py-3.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2381e98af081b6df7f6db300cd88f88e740649d77736e4b53db522d8874bf2dc"},
    {file = "h5py-3.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:667fe23ab33d5a8a6b77970b229e14ae3bb84e4ea3382cc08567a02e1499eedd"},
    {file = "h5py-3.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90286b79abd085e4e65e07c1bd7ee65a0f15818ea107f44b175d2dfe1a4674b7"},
    {file = "h5py-3.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c013d2e79c00f28ffd0cc24e68665ea03ae9069e167087b2adb5727d2736a52"},
    {file = "h5py-3.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:92273ce69ae4983dadb898fd4d3bea5eb90820df953b401282ee69ad648df684"},
    {file = "h5py-3.10.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c97d03f87f215e7759a354460fb4b0d0f27001450b18b23e556e7856a0b21c3"},
    {file = "h5py-3.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86df4c2de68257b8539a18646ceccdcf2c1ce6b1768ada16c8dcfb489eafae20"},
    {file = "h5py-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba9ab36be991119a3ff32d0c7cbe5faf9b8d2375b5278b2aea64effbeba66039"},
    {file = "h5py-3.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c8e4fda19eb769e9a678592e67eaec3a2f069f7570c82d2da909c077aa94339"},
    {file = "h5py-3.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:492305a074327e8d2513011fa9fffeb54ecb28a04ca4c4227d7e1e9616d35641"},
    {file = "h5py-3.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9450464b458cca2c86252b624279115dcaa7260a40d3cb1594bf2b410a2bd1a3"},
    {file = "h5py-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd6f6d1384a9f491732cee233b99cd4bfd6e838a8815cc86722f9d2ee64032af"},
    {file = "h5py-3.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3074ec45d3dc6e178c6f96834cf8108bf4a60ccb5ab044e16909580352010a97"},
    {file = "h5py-3.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:212bb997a91e6a895ce5e2f365ba764debeaef5d2dca5c6fb7098d66607adf99"},
    {file = "h5py-3.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5dfc65ac21fa2f630323c92453cadbe8d4f504726ec42f6a56cf80c2f90d6c52"},
    {file = "h5py-3.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d4682b94fd36ab217352be438abd44c8f357c5449b8995e63886b431d260f3d3"},
    {file = "h5py-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aece0e2e1ed2aab076c41802e50a0c3e5ef8816d60ece39107d68717d4559824"},
    {file = "h5py-3.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43a61b2c2ad65b1fabc28802d133eed34debcc2c8b420cb213d3d4ef4d3e2229"},
    {file = "h5py-3.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:ae2f0201c950059676455daf92700eeb57dcf5caaf71b9e1328e6e6593601770"},
    {file = "h5py-3.10.0.tar.gz", hash = "sha256:d93adc48ceeb33347eb24a634fb787efc7ae4644e6ea4ba733d099605045c049"},
]

[package.dependencies]
numpy = ">=1.17.3"

[[package]]
name = "huggingface-hub"
version = "0.21.3"
@@ -3103,4 +3140,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "1de157eb9ba5b7016e43385b8e9ac507896103fe1ed70f1c7af2b0de3fa05dc1"
content-hash = "9c3e86956dd11bc8d7823e5e6c5e74a073051b495f71f96179113d99791f7ca0"

pyproject.toml

@@ -48,6 +48,7 @@ opencv-python = "^4.9.0.80"
diffusion-policy = {git = "https://github.com/real-stanford/diffusion_policy"}
diffusers = "^0.26.3"
torchvision = "^0.17.1"
h5py = "^3.10.0"
[tool.poetry.group.dev.dependencies]
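
With h5py now a dependency, the episode layout consumed by _download_and_preproc can be inspected directly. A short sketch using only dataset keys that appear in the diff above (the file path is illustrative):

import h5py

# Illustrative path; episodes are stored as episode_{i}.hdf5 in the raw data dir.
with h5py.File("episode_0.hdf5", "r") as ep:
    print(ep["/action"].shape)                   # (T, action_dim)
    print(ep["/observations/qpos"].shape)        # (T, state_dim)
    print(ep["/observations/images/top"].shape)  # (T, H, W, C), channel-last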