Add download_and_upload_dataset.py in script, update all datasets, update online training
This commit is contained in:
parent
c6aca7fe44
commit
67d79732f9
|
@ -1,72 +1,17 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import gdown
|
||||
import h5py
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||
|
||||
FOLDER_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
||||
}
|
||||
|
||||
EP48_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
||||
}
|
||||
|
||||
EP49_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
||||
}
|
||||
|
||||
NUM_EPISODES = {
|
||||
"aloha_sim_insertion_human": 50,
|
||||
"aloha_sim_insertion_scripted": 50,
|
||||
"aloha_sim_transfer_cube_human": 50,
|
||||
"aloha_sim_transfer_cube_scripted": 50,
|
||||
}
|
||||
|
||||
EPISODE_LEN = {
|
||||
"aloha_sim_insertion_human": 500,
|
||||
"aloha_sim_insertion_scripted": 400,
|
||||
"aloha_sim_transfer_cube_human": 400,
|
||||
"aloha_sim_transfer_cube_scripted": 400,
|
||||
}
|
||||
|
||||
CAMERAS = {
|
||||
"aloha_sim_insertion_human": ["top"],
|
||||
"aloha_sim_insertion_scripted": ["top"],
|
||||
"aloha_sim_transfer_cube_human": ["top"],
|
||||
"aloha_sim_transfer_cube_scripted": ["top"],
|
||||
}
|
||||
|
||||
|
||||
def download(data_dir, dataset_id):
|
||||
assert dataset_id in FOLDER_URLS
|
||||
assert dataset_id in EP48_URLS
|
||||
assert dataset_id in EP49_URLS
|
||||
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
|
||||
|
||||
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
||||
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
|
||||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
|
||||
class AlohaDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
|
||||
"""
|
||||
|
||||
available_datasets = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
|
@ -79,129 +24,40 @@ class AlohaDataset(torch.utils.data.Dataset):
|
|||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
root: Path | None = None,
|
||||
version: str | None = "v1.0",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
|
||||
return len(self.data_dict)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
return len(self.data_dict.unique("episode_id"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
item = self.data_dict[idx]
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.data_dict,
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
download(raw_dir, self.dataset_id)
|
||||
|
||||
total_frames = 0
|
||||
logging.info("Compute total number of frames to initialize offline buffer")
|
||||
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
total_frames += ep["/action"].shape[0] - 1
|
||||
logging.info(f"{total_frames=}")
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
frame_idx = 0
|
||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_dict = {
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.state": state,
|
||||
# TODO(rcadene): compute reward and success
|
||||
# "next.reward": reward[1:],
|
||||
"next.done": done[1:],
|
||||
# "next.success": success[1:],
|
||||
}
|
||||
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_dict[f"observation.images.{cam}"] = image[:-1]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
|
||||
assert len(self.data_ids_per_episode[ep_id]) == num_frames
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
frame_idx += num_frames
|
||||
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
@ -8,11 +7,6 @@ from torchvision.transforms import v2
|
|||
from lerobot.common.datasets.utils import compute_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||
# to load a subset of our datasets for faster continuous integration.
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
|
@ -57,12 +51,11 @@ def make_dataset(
|
|||
# instantiate a one frame dataset with light transform
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
|
||||
# load stats if the file exists already or compute stats and save it
|
||||
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
|
||||
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
|
||||
if precomputed_stats_path.exists():
|
||||
stats = torch.load(precomputed_stats_path)
|
||||
else:
|
||||
|
@ -94,7 +87,6 @@ def make_dataset(
|
|||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
|
|
@ -1,25 +1,12 @@
|
|||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset, load_dataset
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip, load_previous_and_future_frames
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/pusht
|
||||
|
||||
Arguments
|
||||
----------
|
||||
|
@ -35,32 +22,19 @@ class PushtDataset(torch.utils.data.Dataset):
|
|||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
root: Path | None = None,
|
||||
version: str | None = "v1.0",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
# self.data_dir = self.root / f"{self.dataset_id}"
|
||||
# if (self.data_dir / "data_dict.pth").exists() and (
|
||||
# self.data_dir / "data_ids_per_episode.pth"
|
||||
# ).exists():
|
||||
# self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
# self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
# else:
|
||||
# self._download_and_preproc_obsolete()
|
||||
# self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
# torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
# torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
self.data_dict = load_dataset("lerobot/pusht", split="train")
|
||||
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train")
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
|
@ -87,135 +61,3 @@ class PushtDataset(torch.utils.data.Dataset):
|
|||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||
|
||||
# load
|
||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
||||
zarr_path
|
||||
) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
# to create test artifact
|
||||
# num_episodes = 1
|
||||
# total_frames = 50
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
|
||||
# TODO: verify that goal pose is expected to be fixed
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
imgs = torch.from_numpy(dataset_dict["img"])
|
||||
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = imgs[idx0:idx1]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
state = states[idx0:idx1]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
||||
reward = torch.zeros(num_frames)
|
||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
|
||||
success[i] = coverage > SUCCESS_THRESHOLD
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[idx0:idx1],
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
dataset = Dataset.from_dict(self.data_dict)
|
||||
dataset = dataset.with_format("torch")
|
||||
|
||||
def add_episode_data_id_from_to(frame):
|
||||
ep_id = frame["episode"].item()
|
||||
frame["episode_data_id_from"] = self.data_ids_per_episode[ep_id][0]
|
||||
frame["episode_data_id_to"] = self.data_ids_per_episode[ep_id][-1]
|
||||
return frame
|
||||
|
||||
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
|
||||
dataset = dataset.rename_column("episode", "episode_id")
|
||||
dataset.push_to_hub("lerobot/pusht", token=True)
|
||||
|
|
|
@ -1,39 +1,11 @@
|
|||
import io
|
||||
import zipfile
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import requests
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
print(f"downloading from {url}")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||
|
||||
zip_file = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
zip_file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
zip_file.seek(0)
|
||||
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_folder)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[torch.Tensor],
|
||||
data_dict: dict[torch.Tensor],
|
||||
|
|
|
@ -1,30 +1,14 @@
|
|||
import pickle
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||
|
||||
|
||||
def download(raw_dir):
|
||||
import gdown
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
zip_path = raw_dir / "data.zip"
|
||||
gdown.download(url, str(zip_path), quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
zip_path.unlink()
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
|
||||
class XarmDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/xarm_lift_medium
|
||||
"""
|
||||
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
|
@ -34,130 +18,40 @@ class XarmDataset(torch.utils.data.Dataset):
|
|||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
version: str | None = "v1.0",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
|
||||
return len(self.data_dict)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
return len(self.data_dict.unique("episode_id"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
item = self.data_dict[idx]
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.data_dict,
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
if not raw_dir.exists():
|
||||
download(raw_dir)
|
||||
|
||||
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
idx1 = 0
|
||||
episode_id = 0
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
idx1 += 1
|
||||
|
||||
if not dataset_dict["dones"][i]:
|
||||
continue
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
|
||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": next_image,
|
||||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
episode_id += 1
|
||||
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
|
|
@ -0,0 +1,415 @@
|
|||
"""
|
||||
This file contains all obsolete download scripts. They are centralized here to not have to load
|
||||
useless dependencies when using datasets.
|
||||
"""
|
||||
|
||||
import io
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
import zipfile
|
||||
|
||||
import requests
|
||||
|
||||
print(f"downloading from {url}")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||
|
||||
zip_file = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
zip_file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
zip_file.seek(0)
|
||||
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_folder)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
|
||||
from lerobot.common.policies.diffusion.replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
# as define in env
|
||||
success_threshold = 0.95 # 95% coverage,
|
||||
|
||||
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
|
||||
root = Path(root)
|
||||
raw_dir = root / f"{dataset_id}_raw"
|
||||
zarr_path = (raw_dir / pusht_zarr).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(pusht_url, raw_dir)
|
||||
|
||||
# load
|
||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
# to create test artifact
|
||||
# num_episodes = 1
|
||||
# total_frames = 50
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
|
||||
# TODO: verify that goal pose is expected to be fixed
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
imgs = torch.from_numpy(dataset_dict["img"])
|
||||
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = imgs[idx0:idx1]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
state = states[idx0:idx1]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
||||
reward = torch.zeros(num_frames)
|
||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
||||
success[i] = coverage > success_threshold
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[idx0:idx1],
|
||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
dataset = Dataset.from_dict(data_dict)
|
||||
dataset = dataset.with_format("torch")
|
||||
|
||||
def add_episode_data_id_from_to(frame):
|
||||
ep_id = frame["episode_id"].item()
|
||||
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
|
||||
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
|
||||
return frame
|
||||
|
||||
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
|
||||
def download_and_upload_xarm(root, dataset_id, fps=15):
|
||||
root = Path(root)
|
||||
raw_dir = root / f"{dataset_id}_raw"
|
||||
if not raw_dir.exists():
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
zip_path = raw_dir / "data.zip"
|
||||
gdown.download(url, str(zip_path), quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
zip_path.unlink()
|
||||
|
||||
dataset_path = root / f"{dataset_id}" / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
|
||||
data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
idx1 = 0
|
||||
episode_id = 0
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
idx1 += 1
|
||||
|
||||
if not dataset_dict["dones"][i]:
|
||||
continue
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
|
||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.image": next_image,
|
||||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
episode_id += 1
|
||||
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
dataset = Dataset.from_dict(data_dict)
|
||||
dataset = dataset.with_format("torch")
|
||||
|
||||
def add_episode_data_id_from_to(frame):
|
||||
ep_id = frame["episode_id"].item()
|
||||
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
|
||||
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
|
||||
return frame
|
||||
|
||||
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
|
||||
def download_and_upload_aloha(root, dataset_id, fps=50):
|
||||
folder_urls = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
||||
}
|
||||
|
||||
ep48_urls = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
||||
}
|
||||
|
||||
ep49_urls = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
||||
}
|
||||
|
||||
num_episodes = {
|
||||
"aloha_sim_insertion_human": 50,
|
||||
"aloha_sim_insertion_scripted": 50,
|
||||
"aloha_sim_transfer_cube_human": 50,
|
||||
"aloha_sim_transfer_cube_scripted": 50,
|
||||
}
|
||||
|
||||
episode_len = {
|
||||
"aloha_sim_insertion_human": 500,
|
||||
"aloha_sim_insertion_scripted": 400,
|
||||
"aloha_sim_transfer_cube_human": 400,
|
||||
"aloha_sim_transfer_cube_scripted": 400,
|
||||
}
|
||||
|
||||
cameras = {
|
||||
"aloha_sim_insertion_human": ["top"],
|
||||
"aloha_sim_insertion_scripted": ["top"],
|
||||
"aloha_sim_transfer_cube_human": ["top"],
|
||||
"aloha_sim_transfer_cube_scripted": ["top"],
|
||||
}
|
||||
|
||||
root = Path(root)
|
||||
raw_dir = root / f"{dataset_id}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
import gdown
|
||||
|
||||
assert dataset_id in folder_urls
|
||||
assert dataset_id in ep48_urls
|
||||
assert dataset_id in ep49_urls
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir))
|
||||
|
||||
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
||||
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
|
||||
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
frame_idx = 0
|
||||
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
num_frames = ep["/action"].shape[0]
|
||||
assert episode_len[dataset_id] == num_frames
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_dict = {
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.state": state,
|
||||
# TODO(rcadene): compute reward and success
|
||||
# "next.reward": reward,
|
||||
"next.done": done,
|
||||
# "next.success": success,
|
||||
}
|
||||
|
||||
for cam in cameras[dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_dict[f"observation.images.{cam}"] = image
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
|
||||
assert len(data_ids_per_episode[ep_id]) == num_frames
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
frame_idx += num_frames
|
||||
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
total_frames = frame_idx
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
dataset = Dataset.from_dict(data_dict)
|
||||
dataset = dataset.with_format("torch")
|
||||
|
||||
def add_episode_data_id_from_to(frame):
|
||||
ep_id = frame["episode_id"].item()
|
||||
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
|
||||
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
|
||||
return frame
|
||||
|
||||
dataset = dataset.map(add_episode_data_id_from_to)
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
root = "data"
|
||||
# download_and_upload_pusht(root, dataset_id="pusht")
|
||||
# download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
|
||||
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human")
|
||||
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted")
|
||||
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human")
|
||||
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_scripted")
|
|
@ -41,6 +41,7 @@ import gymnasium as gym
|
|||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
@ -199,30 +200,28 @@ def eval_policy(
|
|||
ep_dicts = []
|
||||
num_episodes = dones.shape[0]
|
||||
total_frames = 0
|
||||
idx0 = idx1 = 0
|
||||
data_ids_per_episode = {}
|
||||
idx_from = 0
|
||||
for ep_id in range(num_episodes):
|
||||
num_frames = done_indices[ep_id].item() + 1
|
||||
total_frames += num_frames
|
||||
|
||||
# TODO(rcadene): We need to add a missing last frame which is the observation
|
||||
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
ep_dict = {
|
||||
"action": actions[ep_id, :num_frames],
|
||||
"episode": torch.tensor([ep_id] * num_frames),
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": dones[ep_id, :num_frames],
|
||||
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
||||
"episode_data_id_from": torch.tensor([idx_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
|
||||
}
|
||||
for key in observations:
|
||||
ep_dict[key] = observations[key][ep_id, :num_frames]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
total_frames += num_frames
|
||||
idx1 += num_frames
|
||||
|
||||
data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
|
||||
|
||||
idx0 = idx1
|
||||
idx_from += num_frames
|
||||
|
||||
# similar logic is implemented in dataset preprocessing
|
||||
data_dict = {}
|
||||
|
@ -231,6 +230,8 @@ def eval_policy(
|
|||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
data_dict = Dataset.from_dict(data_dict).with_format("torch")
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||
|
||||
|
@ -280,10 +281,7 @@ def eval_policy(
|
|||
"eval_s": time.time() - start,
|
||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||
},
|
||||
"episodes": {
|
||||
"data_dict": data_dict,
|
||||
"data_ids_per_episode": data_ids_per_episode,
|
||||
},
|
||||
"episodes": data_dict,
|
||||
}
|
||||
if max_episodes_rendered > 0:
|
||||
info["videos"] = videos
|
||||
|
|
|
@ -4,6 +4,8 @@ from pathlib import Path
|
|||
|
||||
import hydra
|
||||
import torch
|
||||
from datasets import concatenate_datasets
|
||||
from datasets.utils.logging import disable_progress_bar
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
|
@ -128,29 +130,33 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
|
|||
return -(n_off * pc_on) / (n_on * (pc_on - 1))
|
||||
|
||||
|
||||
def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
|
||||
data_dict = episodes["data_dict"]
|
||||
data_ids_per_episode = episodes["data_ids_per_episode"]
|
||||
def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples):
|
||||
first_episode_id = data_dict.select_columns("episode_id")[0]["episode_id"].item()
|
||||
first_index = data_dict.select_columns("index")[0]["index"].item()
|
||||
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
|
||||
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
|
||||
|
||||
if len(online_dataset) == 0:
|
||||
# initialize online dataset
|
||||
online_dataset.data_dict = data_dict
|
||||
online_dataset.data_ids_per_episode = data_ids_per_episode
|
||||
else:
|
||||
# find episode index and data frame indices according to previous episode in online_dataset
|
||||
start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
|
||||
start_episode = online_dataset.data_dict["episode_id"][-1].item() + 1
|
||||
start_index = online_dataset.data_dict["index"][-1].item() + 1
|
||||
data_dict["episode"] += start_episode
|
||||
data_dict["index"] += start_index
|
||||
|
||||
def shift_indices(example):
|
||||
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
|
||||
example["episode_id"] += start_episode
|
||||
example["index"] += start_index
|
||||
example["episode_data_id_from"] += start_index
|
||||
example["episode_data_id_to"] += start_index
|
||||
return example
|
||||
|
||||
disable_progress_bar() # map has a tqdm progres bar
|
||||
data_dict = data_dict.map(shift_indices)
|
||||
|
||||
# extend online dataset
|
||||
for key in data_dict:
|
||||
# TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
|
||||
online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
|
||||
for ep_id in data_ids_per_episode:
|
||||
online_dataset.data_ids_per_episode[ep_id + start_episode] = (
|
||||
data_ids_per_episode[ep_id] + start_index
|
||||
)
|
||||
online_dataset.data_dict = concatenate_datasets([online_dataset.data_dict, data_dict])
|
||||
|
||||
# update the concatenated dataset length used during sampling
|
||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
||||
|
@ -269,7 +275,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
# create an empty online dataset similar to offline dataset
|
||||
online_dataset = deepcopy(offline_dataset)
|
||||
online_dataset.data_dict = {}
|
||||
online_dataset.data_ids_per_episode = {}
|
||||
|
||||
# create dataloader for online training
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
|
|
Loading…
Reference in New Issue