Add download_and_upload_dataset.py in script, update all datasets, update online training

This commit is contained in:
Cadene 2024-04-15 21:26:33 +00:00
parent c6aca7fe44
commit 67d79732f9
8 changed files with 493 additions and 519 deletions

View File

@ -1,72 +1,17 @@
import logging
from pathlib import Path
import einops
import gdown
import h5py
import torch
import tqdm
from datasets import load_dataset
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
FOLDER_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
}
EP48_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
}
EP49_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
}
NUM_EPISODES = {
"aloha_sim_insertion_human": 50,
"aloha_sim_insertion_scripted": 50,
"aloha_sim_transfer_cube_human": 50,
"aloha_sim_transfer_cube_scripted": 50,
}
EPISODE_LEN = {
"aloha_sim_insertion_human": 500,
"aloha_sim_insertion_scripted": 400,
"aloha_sim_transfer_cube_human": 400,
"aloha_sim_transfer_cube_scripted": 400,
}
CAMERAS = {
"aloha_sim_insertion_human": ["top"],
"aloha_sim_insertion_scripted": ["top"],
"aloha_sim_transfer_cube_human": ["top"],
"aloha_sim_transfer_cube_scripted": ["top"],
}
def download(data_dir, dataset_id):
assert dataset_id in FOLDER_URLS
assert dataset_id in EP48_URLS
assert dataset_id in EP49_URLS
data_dir.mkdir(parents=True, exist_ok=True)
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
from lerobot.common.datasets.utils import load_previous_and_future_frames
class AlohaDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
"""
available_datasets = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
@ -79,129 +24,40 @@ class AlohaDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.2",
root: Path | None = None,
version: str | None = "v1.0",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.transform = transform
self.delta_timestamps = delta_timestamps
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
self.data_dict = self.data_dict.with_format("torch")
@property
def num_samples(self) -> int:
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
return len(self.data_dict)
@property
def num_episodes(self) -> int:
return len(self.data_ids_per_episode)
return len(self.data_dict.unique("episode_id"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = {}
item = self.data_dict[idx]
# get episode id and timestamp of the sampled frame
current_ts = self.data_dict["timestamp"][idx].item()
episode = self.data_dict["episode"][idx].item()
for key in self.data_dict:
if self.delta_timestamps is not None and key in self.delta_timestamps:
data, is_pad = load_data_with_delta_timestamps(
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.data_dict,
self.data_ids_per_episode,
self.delta_timestamps,
key,
current_ts,
episode,
)
item[key] = data
item[f"{key}_is_pad"] = is_pad
else:
item[key] = self.data_dict[key][idx]
if self.transform is not None:
item = self.transform(item)
return item
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
if not raw_dir.is_dir():
download(raw_dir, self.dataset_id)
total_frames = 0
logging.info("Compute total number of frames to initialize offline buffer")
for ep_id in range(NUM_EPISODES[self.dataset_id]):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
total_frames += ep["/action"].shape[0] - 1
logging.info(f"{total_frames=}")
self.data_ids_per_episode = {}
ep_dicts = []
frame_idx = 0
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
ep_dict = {
"observation.state": state,
"action": action,
"episode": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward[1:],
"next.done": done[1:],
# "next.success": success[1:],
}
for cam in CAMERAS[self.dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_dict[f"observation.images.{cam}"] = image[:-1]
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
assert isinstance(ep_id, int)
self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
assert len(self.data_ids_per_episode[ep_id]) == num_frames
ep_dicts.append(ep_dict)
frame_idx += num_frames
self.data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)

View File

@ -1,5 +1,4 @@
import logging
import os
from pathlib import Path
import torch
@ -8,11 +7,6 @@ from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_dataset(
cfg,
@ -57,12 +51,11 @@ def make_dataset(
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
# load stats if the file exists already or compute stats and save it
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
@ -94,7 +87,6 @@ def make_dataset(
dataset = clsfunc(
dataset_id=cfg.dataset_id,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
transform=transforms,
)

View File

@ -1,25 +1,12 @@
from pathlib import Path
import einops
import numpy as np
import torch
import tqdm
from datasets import Dataset, load_dataset
from datasets import load_dataset
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
from lerobot.common.datasets.utils import download_and_extract_zip, load_previous_and_future_frames
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
from lerobot.common.datasets.utils import load_previous_and_future_frames
class PushtDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/pusht
Arguments
----------
@ -35,32 +22,19 @@ class PushtDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.2",
root: Path | None = None,
version: str | None = "v1.0",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.transform = transform
self.delta_timestamps = delta_timestamps
# self.data_dir = self.root / f"{self.dataset_id}"
# if (self.data_dir / "data_dict.pth").exists() and (
# self.data_dir / "data_ids_per_episode.pth"
# ).exists():
# self.data_dict = torch.load(self.data_dir / "data_dict.pth")
# self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
# else:
# self._download_and_preproc_obsolete()
# self.data_dir.mkdir(parents=True, exist_ok=True)
# torch.save(self.data_dict, self.data_dir / "data_dict.pth")
# torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
self.data_dict = load_dataset("lerobot/pusht", split="train")
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train")
self.data_dict = self.data_dict.with_format("torch")
self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
@property
def num_samples(self) -> int:
@ -87,135 +61,3 @@ class PushtDataset(torch.utils.data.Dataset):
item = self.transform(item)
return item
def _download_and_preproc_obsolete(self):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
raise e
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(PUSHT_URL, raw_dir)
# load
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
zarr_path
) # , keys=['img', 'state', 'action'])
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
# to create test artifact
# num_episodes = 1
# total_frames = 50
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"])
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
self.data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
num_frames = idx1 - idx0
assert (episode_ids[idx0:idx1] == episode_id).all()
image = imgs[idx0:idx1]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
state = states[idx0:idx1]
agent_pos = state[:, :2]
block_pos = state[:, 2:4]
block_angle = state[:, 4]
reward = torch.zeros(num_frames)
success = torch.zeros(num_frames, dtype=torch.bool)
done = torch.zeros(num_frames, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage = intersection_area / goal_area
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
success[i] = coverage > SUCCESS_THRESHOLD
# last step of demonstration is considered done
done[-1] = True
ep_dict = {
"observation.image": image,
"observation.state": agent_pos,
"action": actions[idx0:idx1],
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
# "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:],
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]),
}
ep_dicts.append(ep_dict)
assert isinstance(episode_id, int)
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
assert len(self.data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
self.data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(self.data_dict)
dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame):
ep_id = frame["episode"].item()
frame["episode_data_id_from"] = self.data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = self.data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
dataset = dataset.rename_column("episode", "episode_id")
dataset.push_to_hub("lerobot/pusht", token=True)

View File

@ -1,39 +1,11 @@
import io
import zipfile
from copy import deepcopy
from math import ceil
from pathlib import Path
import einops
import requests
import torch
import tqdm
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
print(f"downloading from {url}")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
zip_file = io.BytesIO()
for chunk in response.iter_content(chunk_size=1024):
if chunk:
zip_file.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
zip_file.seek(0)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(destination_folder)
return True
else:
return False
def load_previous_and_future_frames(
item: dict[torch.Tensor],
data_dict: dict[torch.Tensor],

View File

@ -1,30 +1,14 @@
import pickle
import zipfile
from pathlib import Path
import torch
import tqdm
from datasets import load_dataset
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
def download(raw_dir):
import gdown
raw_dir.mkdir(parents=True, exist_ok=True)
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
print("Extracting...")
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
zip_path.unlink()
from lerobot.common.datasets.utils import load_previous_and_future_frames
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
"""
available_datasets = [
"xarm_lift_medium",
]
@ -34,130 +18,40 @@ class XarmDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
version: str | None = "v1.0",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.transform = transform
self.delta_timestamps = delta_timestamps
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
self.data_dict = self.data_dict.with_format("torch")
@property
def num_samples(self) -> int:
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
return len(self.data_dict)
@property
def num_episodes(self) -> int:
return len(self.data_ids_per_episode)
return len(self.data_dict.unique("episode_id"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = {}
item = self.data_dict[idx]
# get episode id and timestamp of the sampled frame
current_ts = self.data_dict["timestamp"][idx].item()
episode = self.data_dict["episode"][idx].item()
for key in self.data_dict:
if self.delta_timestamps is not None and key in self.delta_timestamps:
data, is_pad = load_data_with_delta_timestamps(
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.data_dict,
self.data_ids_per_episode,
self.delta_timestamps,
key,
current_ts,
episode,
)
item[key] = data
item[f"{key}_is_pad"] = is_pad
else:
item[key] = self.data_dict[key][idx]
if self.transform is not None:
item = self.transform(item)
return item
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
if not raw_dir.exists():
download(raw_dir)
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
total_frames = dataset_dict["actions"].shape[0]
self.data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
idx1 = 0
episode_id = 0
for i in tqdm.tqdm(range(total_frames)):
idx1 += 1
if not dataset_dict["dones"][i]:
continue
num_frames = idx1 - idx0
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
ep_dict = {
"observation.image": image,
"observation.state": state,
"action": action,
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
# "next.observation.image": next_image,
# "next.observation.state": next_state,
"next.reward": next_reward,
"next.done": next_done,
}
ep_dicts.append(ep_dict)
assert isinstance(episode_id, int)
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
assert len(self.data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
episode_id += 1
self.data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)

View File

@ -0,0 +1,415 @@
"""
This file contains all obsolete download scripts. They are centralized here to not have to load
useless dependencies when using datasets.
"""
import io
import pickle
from pathlib import Path
import einops
import h5py
import numpy as np
import torch
import tqdm
from datasets import Dataset
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
import zipfile
import requests
print(f"downloading from {url}")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
zip_file = io.BytesIO()
for chunk in response.iter_content(chunk_size=1024):
if chunk:
zip_file.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
zip_file.seek(0)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(destination_folder)
return True
else:
return False
def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
raise e
# as define in env
success_threshold = 0.95 # 95% coverage,
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
zarr_path = (raw_dir / pusht_zarr).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(pusht_url, raw_dir)
# load
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
# to create test artifact
# num_episodes = 1
# total_frames = 50
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"])
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
num_frames = idx1 - idx0
assert (episode_ids[idx0:idx1] == episode_id).all()
image = imgs[idx0:idx1]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
state = states[idx0:idx1]
agent_pos = state[:, :2]
block_pos = state[:, 2:4]
block_angle = state[:, 4]
reward = torch.zeros(num_frames)
success = torch.zeros(num_frames, dtype=torch.bool)
done = torch.zeros(num_frames, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage = intersection_area / goal_area
reward[i] = np.clip(coverage / success_threshold, 0, 1)
success[i] = coverage > success_threshold
# last step of demonstration is considered done
done[-1] = True
ep_dict = {
"observation.image": image,
"observation.state": agent_pos,
"action": actions[idx0:idx1],
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:],
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]),
}
ep_dicts.append(ep_dict)
assert isinstance(episode_id, int)
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
assert len(data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(data_dict)
dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame):
ep_id = frame["episode_id"].item()
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
def download_and_upload_xarm(root, dataset_id, fps=15):
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
if not raw_dir.exists():
import zipfile
import gdown
raw_dir.mkdir(parents=True, exist_ok=True)
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
print("Extracting...")
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
zip_path.unlink()
dataset_path = root / f"{dataset_id}" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
total_frames = dataset_dict["actions"].shape[0]
data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
idx1 = 0
episode_id = 0
for i in tqdm.tqdm(range(total_frames)):
idx1 += 1
if not dataset_dict["dones"][i]:
continue
num_frames = idx1 - idx0
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
ep_dict = {
"observation.image": image,
"observation.state": state,
"action": action,
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": next_image,
# "next.observation.state": next_state,
"next.reward": next_reward,
"next.done": next_done,
}
ep_dicts.append(ep_dict)
assert isinstance(episode_id, int)
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
assert len(data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
episode_id += 1
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(data_dict)
dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame):
ep_id = frame["episode_id"].item()
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
def download_and_upload_aloha(root, dataset_id, fps=50):
folder_urls = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
}
ep48_urls = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
}
ep49_urls = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
}
num_episodes = {
"aloha_sim_insertion_human": 50,
"aloha_sim_insertion_scripted": 50,
"aloha_sim_transfer_cube_human": 50,
"aloha_sim_transfer_cube_scripted": 50,
}
episode_len = {
"aloha_sim_insertion_human": 500,
"aloha_sim_insertion_scripted": 400,
"aloha_sim_transfer_cube_human": 400,
"aloha_sim_transfer_cube_scripted": 400,
}
cameras = {
"aloha_sim_insertion_human": ["top"],
"aloha_sim_insertion_scripted": ["top"],
"aloha_sim_transfer_cube_human": ["top"],
"aloha_sim_transfer_cube_scripted": ["top"],
}
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
if not raw_dir.is_dir():
import gdown
assert dataset_id in folder_urls
assert dataset_id in ep48_urls
assert dataset_id in ep49_urls
raw_dir.mkdir(parents=True, exist_ok=True)
gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir))
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
data_ids_per_episode = {}
ep_dicts = []
frame_idx = 0
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]
assert episode_len[dataset_id] == num_frames
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
ep_dict = {
"observation.state": state,
"action": action,
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward,
"next.done": done,
# "next.success": success,
}
for cam in cameras[dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_dict[f"observation.images.{cam}"] = image
# ep_dict[f"next.observation.images.{cam}"] = image
assert isinstance(ep_id, int)
data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
assert len(data_ids_per_episode[ep_id]) == num_frames
ep_dicts.append(ep_dict)
frame_idx += num_frames
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
total_frames = frame_idx
data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(data_dict)
dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame):
ep_id = frame["episode_id"].item()
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
if __name__ == "__main__":
root = "data"
# download_and_upload_pusht(root, dataset_id="pusht")
# download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human")
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted")
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human")
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_scripted")

View File

@ -41,6 +41,7 @@ import gymnasium as gym
import imageio
import numpy as np
import torch
from datasets import Dataset
from huggingface_hub import snapshot_download
from lerobot.common.datasets.factory import make_dataset
@ -199,30 +200,28 @@ def eval_policy(
ep_dicts = []
num_episodes = dones.shape[0]
total_frames = 0
idx0 = idx1 = 0
data_ids_per_episode = {}
idx_from = 0
for ep_id in range(num_episodes):
num_frames = done_indices[ep_id].item() + 1
total_frames += num_frames
# TODO(rcadene): We need to add a missing last frame which is the observation
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode": torch.tensor([ep_id] * num_frames),
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_id_from": torch.tensor([idx_from] * num_frames),
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id, :num_frames]
ep_dicts.append(ep_dict)
total_frames += num_frames
idx1 += num_frames
data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
idx0 = idx1
idx_from += num_frames
# similar logic is implemented in dataset preprocessing
data_dict = {}
@ -231,6 +230,8 @@ def eval_policy(
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = Dataset.from_dict(data_dict).with_format("torch")
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
@ -280,10 +281,7 @@ def eval_policy(
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
"episodes": {
"data_dict": data_dict,
"data_ids_per_episode": data_ids_per_episode,
},
"episodes": data_dict,
}
if max_episodes_rendered > 0:
info["videos"] = videos

View File

@ -4,6 +4,8 @@ from pathlib import Path
import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils.logging import disable_progress_bar
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@ -128,29 +130,33 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
return -(n_off * pc_on) / (n_on * (pc_on - 1))
def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
data_dict = episodes["data_dict"]
data_ids_per_episode = episodes["data_ids_per_episode"]
def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples):
first_episode_id = data_dict.select_columns("episode_id")[0]["episode_id"].item()
first_index = data_dict.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.data_dict = data_dict
online_dataset.data_ids_per_episode = data_ids_per_episode
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
start_episode = online_dataset.data_dict["episode_id"][-1].item() + 1
start_index = online_dataset.data_dict["index"][-1].item() + 1
data_dict["episode"] += start_episode
data_dict["index"] += start_index
def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode
example["index"] += start_index
example["episode_data_id_from"] += start_index
example["episode_data_id_to"] += start_index
return example
disable_progress_bar() # map has a tqdm progres bar
data_dict = data_dict.map(shift_indices)
# extend online dataset
for key in data_dict:
# TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
for ep_id in data_ids_per_episode:
online_dataset.data_ids_per_episode[ep_id + start_episode] = (
data_ids_per_episode[ep_id] + start_index
)
online_dataset.data_dict = concatenate_datasets([online_dataset.data_dict, data_dict])
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
@ -269,7 +275,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.data_dict = {}
online_dataset.data_ids_per_episode = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])