Add aloha dataset
parent 49c0955f97
commit d782b029e1
@@ -0,0 +1,185 @@
import logging
from pathlib import Path
from typing import Callable

import einops
import gdown
import h5py
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer

from lerobot.common.datasets.abstract import AbstractExperienceReplay

DATASET_IDS = [
    "aloha_sim_insertion_human",
    "aloha_sim_insertion_scripted",
    "aloha_sim_transfer_cube_human",
    "aloha_sim_transfer_cube_scripted",
]

FOLDER_URLS = {
    "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
    "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
    "aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
}

EP48_URLS = {
    "aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
    "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
    "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
}

EP49_URLS = {
    "aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
    "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
    "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
}

NUM_EPISODES = {
    "aloha_sim_insertion_human": 50,
    "aloha_sim_insertion_scripted": 50,
    "aloha_sim_transfer_cube_human": 50,
    "aloha_sim_transfer_cube_scripted": 50,
}

EPISODE_LEN = {
    "aloha_sim_insertion_human": 500,
    "aloha_sim_insertion_scripted": 400,
    "aloha_sim_transfer_cube_human": 400,
    "aloha_sim_transfer_cube_scripted": 400,
}

CAMERAS = {
    "aloha_sim_insertion_human": ["top"],
    "aloha_sim_insertion_scripted": ["top"],
    "aloha_sim_transfer_cube_human": ["top"],
    "aloha_sim_transfer_cube_scripted": ["top"],
}


def download(data_dir, dataset_id):
    assert dataset_id in DATASET_IDS
    assert dataset_id in FOLDER_URLS
    assert dataset_id in EP48_URLS
    assert dataset_id in EP49_URLS

    data_dir.mkdir(parents=True, exist_ok=True)

    gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)

    # because of the 50 files limit per directory, two files episode 48 and 49 were missing
    gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
    gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)


class AlohaExperienceReplay(AbstractExperienceReplay):
    def __init__(
        self,
        dataset_id: str,
        batch_size: int = None,
        *,
        shuffle: bool = True,
        root: Path = None,
        pin_memory: bool = False,
        prefetch: int = None,
        sampler: SliceSampler = None,
        collate_fn: Callable = None,
        writer: Writer = None,
        transform: "torchrl.envs.Transform" = None,
    ):
        assert dataset_id in DATASET_IDS

        super().__init__(
            dataset_id,
            batch_size,
            shuffle=shuffle,
            root=root,
            pin_memory=pin_memory,
            prefetch=prefetch,
            sampler=sampler,
            collate_fn=collate_fn,
            writer=writer,
            transform=transform,
        )

    @property
    def stats_patterns(self) -> dict:
        d = {
            ("observation", "state"): "b c -> 1 c",
            ("action"): "b c -> 1 c",
        }
        for cam in CAMERAS[self.dataset_id]:
            d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
        return d

    @property
    def image_keys(self) -> list:
        return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]

    # def _is_downloaded(self) -> bool:
    #     return False

    def _download_and_preproc(self):
        raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
        if not raw_dir.is_dir():
            download(raw_dir, self.dataset_id)

        total_num_frames = 0
        logging.info("Compute total number of frames to initialize offline buffer")
        for ep_id in range(NUM_EPISODES[self.dataset_id]):
            ep_path = raw_dir / f"episode_{ep_id}.hdf5"
            with h5py.File(ep_path, "r") as ep:
                total_num_frames += ep["/action"].shape[0] - 1
        logging.info(f"{total_num_frames=}")

        logging.info("Initialize and feed offline buffer")
        idxtd = 0
        for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
            ep_path = raw_dir / f"episode_{ep_id}.hdf5"
            with h5py.File(ep_path, "r") as ep:
                ep_num_frames = ep["/action"].shape[0]

                # last step of demonstration is considered done
                done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
                done[-1] = True

                state = torch.from_numpy(ep["/observations/qpos"][:])
                action = torch.from_numpy(ep["/action"][:])

                ep_td = TensorDict(
                    {
                        ("observation", "state"): state[:-1],
                        "action": action[:-1],
                        "episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
                        "frame_id": torch.arange(0, ep_num_frames - 1, 1),
                        ("next", "observation", "state"): state[1:],
                        # TODO: compute reward and success
                        # ("next", "reward"): reward[1:],
                        ("next", "done"): done[1:],
                        # ("next", "success"): success[1:],
                    },
                    batch_size=ep_num_frames - 1,
                )

                for cam in CAMERAS[self.dataset_id]:
                    image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
                    image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
                    ep_td["observation", "image", cam] = image[:-1]
                    ep_td["next", "observation", "image", cam] = image[1:]

                if ep_id == 0:
                    # hack to initialize tensordict data structure to store episodes
                    td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir)

                td_data[idxtd : idxtd + len(ep_td)] = ep_td
                idxtd = idxtd + len(ep_td)

        return TensorStorage(td_data.lock_())
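Note: a minimal usage sketch of the class added above, for context only. The download-on-first-use and sampling behavior come from AbstractExperienceReplay and torchrl's replay buffer, which are not part of this diff, and the batch size and root path below are hypothetical example values.

from pathlib import Path

from lerobot.common.datasets.aloha import AlohaExperienceReplay

# One of the four DATASET_IDS defined above.
offline_buffer = AlohaExperienceReplay(
    "aloha_sim_insertion_human",
    batch_size=32,
    root=Path("data"),  # hypothetical local cache directory
)

# Assumes the base class exposes torchrl's sample(), which returns a TensorDict batch.
batch = offline_buffer.sample()
print(batch["observation", "image", "top"].shape)  # images from the "top" camera, (32, c, h, w)
print(batch["observation", "state"].shape)         # qpos vectors, (32, state_dim)
print(batch["action"].shape)                        # (32, action_dim)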
@@ -66,6 +66,12 @@ def make_offline_buffer(

        clsfunc = PushtExperienceReplay
        dataset_id = "pusht"

    elif cfg.env.name == "aloha":
        from lerobot.common.datasets.aloha import AlohaExperienceReplay

        clsfunc = AlohaExperienceReplay
        dataset_id = f"aloha_{cfg.env.task}"
    else:
        raise ValueError(cfg.env.name)

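Note: a sketch of how the new branch resolves for the aloha config added below. This is not code from the commit; the full make_offline_buffer signature and the rest of the cfg structure are outside this diff.

from omegaconf import OmegaConf

# Rough shape of the cfg subtree Hydra would compose from the aloha config below.
cfg = OmegaConf.create({"env": {"name": "aloha", "task": "sim_insertion_human"}})

# The new elif branch then selects:
#   clsfunc    = AlohaExperienceReplay
#   dataset_id = f"aloha_{cfg.env.task}"
assert f"aloha_{cfg.env.task}" == "aloha_sim_insertion_human"  # one of DATASET_IDS above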
@@ -0,0 +1,25 @@
# @package _global_

eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
# TODO: same as simxarm, need to adjust
offline_steps: 25000
online_steps: 25000

fps: 50

env:
  name: aloha
  task: sim_insertion_human
  from_pixels: True
  pixels_only: False
  image_size: 96
  action_repeat: 1
  episode_length: 300
  fps: ${fps}

policy:
  state_dim: 2
  action_dim: 2
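Note: a short sketch of how this config is consumed, under two assumptions: the file lives at a path like lerobot/configs/env/aloha.yaml (the path is not shown in this diff), and it is loaded through Hydra/OmegaConf. The "# @package _global_" directive asks Hydra to merge these keys at the root of the composed config, and "${fps}" is an OmegaConf interpolation that points at the top-level fps value.

from omegaconf import OmegaConf

cfg = OmegaConf.load("lerobot/configs/env/aloha.yaml")  # assumed path
resolved = OmegaConf.to_container(cfg, resolve=True)
print(resolved["env"]["fps"])   # 50, resolved from the top-level fps via ${fps}
print(resolved["env"]["task"])  # "sim_insertion_human", used by the factory branch above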