208 lines
7.8 KiB
Python
208 lines
7.8 KiB
Python
import logging
|
|
from pathlib import Path
|
|
|
|
import einops
|
|
import gdown
|
|
import h5py
|
|
import torch
|
|
import tqdm
|
|
|
|
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
|
|
|
# Google Drive folder of each dataset; `download` fetches it with
# `gdown.download_folder`. Its keys are the canonical dataset ids reused by
# the tables below.
FOLDER_URLS = {
    "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
    "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
    "aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
}

# Direct links to episode 48 of each dataset, which `download` fetches
# separately from the folder and saves as `episode_48.hdf5`.
EP48_URLS = {
    "aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
    "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
    "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
}

# Direct links to episode 49 of each dataset (same reason as EP48_URLS).
EP49_URLS = {
    "aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
    "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
    "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
    "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
}

# Number of `episode_<i>.hdf5` files read per dataset (currently 50 for all).
NUM_EPISODES = dict.fromkeys(FOLDER_URLS, 50)

# Episode length per dataset. This module only stores the table; nothing
# visible here reads it.
EPISODE_LEN = {
    "aloha_sim_insertion_human": 500,
    "aloha_sim_insertion_scripted": 400,
    "aloha_sim_transfer_cube_human": 400,
    "aloha_sim_transfer_cube_scripted": 400,
}

# Camera names whose images are read from `/observations/images/<cam>`.
# A comprehension (not `dict.fromkeys`) so each dataset owns its own list.
CAMERAS = {dataset_id: ["top"] for dataset_id in FOLDER_URLS}
|
|
|
|
|
|
def download(data_dir, dataset_id):
    """Fetch the raw episode files of `dataset_id` into `data_dir`.

    The dataset's Google Drive folder is downloaded first, then episodes 48
    and 49 are fetched individually and stored as `episode_48.hdf5` and
    `episode_49.hdf5` inside `data_dir`.

    Args:
        data_dir: Destination directory (a `pathlib.Path`); created with its
            parents if missing.
        dataset_id: Key present in FOLDER_URLS, EP48_URLS and EP49_URLS.
    """
    # Check every URL table up front so nothing is written for an unknown id.
    for url_table in (FOLDER_URLS, EP48_URLS, EP49_URLS):
        assert dataset_id in url_table

    data_dir.mkdir(parents=True, exist_ok=True)

    gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))

    # The folder download is capped at 50 files per directory, which left
    # episodes 48 and 49 out; fetch them one by one from their direct links.
    for episode_number, url_table in ((48, EP48_URLS), (49, EP49_URLS)):
        destination = data_dir / f"episode_{episode_number}.hdf5"
        gdown.download(url_table[dataset_id], output=str(destination), fuzzy=True)
|
|
|
|
|
|
class AlohaDataset(torch.utils.data.Dataset):
|
|
available_datasets = [
|
|
"aloha_sim_insertion_human",
|
|
"aloha_sim_insertion_scripted",
|
|
"aloha_sim_transfer_cube_human",
|
|
"aloha_sim_transfer_cube_scripted",
|
|
]
|
|
fps = 50
|
|
image_keys = ["observation.images.top"]
|
|
|
|
def __init__(
|
|
self,
|
|
dataset_id: str,
|
|
version: str | None = "v1.2",
|
|
root: Path | None = None,
|
|
transform: callable = None,
|
|
delta_timestamps: dict[list[float]] | None = None,
|
|
):
|
|
super().__init__()
|
|
self.dataset_id = dataset_id
|
|
self.version = version
|
|
self.root = root
|
|
self.transform = transform
|
|
self.delta_timestamps = delta_timestamps
|
|
|
|
self.data_dir = self.root / f"{self.dataset_id}"
|
|
if (self.data_dir / "data_dict.pth").exists() and (
|
|
self.data_dir / "data_ids_per_episode.pth"
|
|
).exists():
|
|
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
|
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
|
else:
|
|
self._download_and_preproc_obsolete()
|
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
|
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
|
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
|
|
|
@property
|
|
def num_samples(self) -> int:
|
|
return len(self.data_dict["index"])
|
|
|
|
@property
|
|
def num_episodes(self) -> int:
|
|
return len(self.data_ids_per_episode)
|
|
|
|
def __len__(self):
|
|
return self.num_samples
|
|
|
|
def __getitem__(self, idx):
|
|
item = {}
|
|
|
|
# get episode id and timestamp of the sampled frame
|
|
current_ts = self.data_dict["timestamp"][idx].item()
|
|
episode = self.data_dict["episode"][idx].item()
|
|
|
|
for key in self.data_dict:
|
|
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
|
data, is_pad = load_data_with_delta_timestamps(
|
|
self.data_dict,
|
|
self.data_ids_per_episode,
|
|
self.delta_timestamps,
|
|
key,
|
|
current_ts,
|
|
episode,
|
|
)
|
|
item[key] = data
|
|
item[f"{key}_is_pad"] = is_pad
|
|
else:
|
|
item[key] = self.data_dict[key][idx]
|
|
|
|
if self.transform is not None:
|
|
item = self.transform(item)
|
|
|
|
return item
|
|
|
|
def _download_and_preproc_obsolete(self):
|
|
assert self.root is not None
|
|
raw_dir = self.root / f"{self.dataset_id}_raw"
|
|
if not raw_dir.is_dir():
|
|
download(raw_dir, self.dataset_id)
|
|
|
|
total_frames = 0
|
|
logging.info("Compute total number of frames to initialize offline buffer")
|
|
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
|
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
|
with h5py.File(ep_path, "r") as ep:
|
|
total_frames += ep["/action"].shape[0] - 1
|
|
logging.info(f"{total_frames=}")
|
|
|
|
self.data_ids_per_episode = {}
|
|
ep_dicts = []
|
|
|
|
frame_idx = 0
|
|
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
|
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
|
with h5py.File(ep_path, "r") as ep:
|
|
num_frames = ep["/action"].shape[0]
|
|
|
|
# last step of demonstration is considered done
|
|
done = torch.zeros(num_frames, dtype=torch.bool)
|
|
done[-1] = True
|
|
|
|
state = torch.from_numpy(ep["/observations/qpos"][:])
|
|
action = torch.from_numpy(ep["/action"][:])
|
|
|
|
ep_dict = {
|
|
"observation.state": state,
|
|
"action": action,
|
|
"episode": torch.tensor([ep_id] * num_frames),
|
|
"frame_id": torch.arange(0, num_frames, 1),
|
|
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
|
# "next.observation.state": state,
|
|
# TODO(rcadene): compute reward and success
|
|
# "next.reward": reward[1:],
|
|
"next.done": done[1:],
|
|
# "next.success": success[1:],
|
|
}
|
|
|
|
for cam in CAMERAS[self.dataset_id]:
|
|
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
|
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
|
ep_dict[f"observation.images.{cam}"] = image[:-1]
|
|
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
|
|
|
|
assert isinstance(ep_id, int)
|
|
self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
|
|
assert len(self.data_ids_per_episode[ep_id]) == num_frames
|
|
|
|
ep_dicts.append(ep_dict)
|
|
|
|
frame_idx += num_frames
|
|
|
|
self.data_dict = {}
|
|
|
|
keys = ep_dicts[0].keys()
|
|
for key in keys:
|
|
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
|
|
|
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|