diff --git a/download_and_upload_dataset.py b/download_and_upload_dataset.py deleted file mode 100644 index 767885ac..00000000 --- a/download_and_upload_dataset.py +++ /dev/null @@ -1,779 +0,0 @@ -""" -This file contains all obsolete download scripts. They are centralized here to not have to load -useless dependencies when using datasets. -""" - -import io -import json -import pickle -import shutil -from pathlib import Path - -import einops -import h5py -import numpy as np -import torch -import tqdm -from datasets import Dataset, Features, Image, Sequence, Value -from huggingface_hub import HfApi -from PIL import Image as PILImage -from safetensors.torch import save_file - -from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch - - -def download_and_upload(root, revision, dataset_id): - # TODO(rcadene, adilzouitine): add community_id/user_id (e.g. "lerobot", "cadene") or repo_id (e.g. "lerobot/pusht") - if "pusht" in dataset_id: - download_and_upload_pusht(root, revision, dataset_id) - elif "xarm" in dataset_id: - download_and_upload_xarm(root, revision, dataset_id) - elif "aloha" in dataset_id: - download_and_upload_aloha(root, revision, dataset_id) - elif "umi" in dataset_id: - download_and_upload_umi(root, revision, dataset_id) - else: - raise ValueError(dataset_id) - - -def concatenate_episodes(ep_dicts): - data_dict = {} - - keys = ep_dicts[0].keys() - for key in keys: - if torch.is_tensor(ep_dicts[0][key][0]): - data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) - else: - if key not in data_dict: - data_dict[key] = [] - for ep_dict in ep_dicts: - for x in ep_dict[key]: - data_dict[key].append(x) - - total_frames = data_dict["frame_index"].shape[0] - data_dict["index"] = torch.arange(0, total_frames, 1) - return data_dict - - -def download_and_extract_zip(url: str, destination_folder: Path) -> bool: - import zipfile - - import requests - - print(f"downloading from {url}") - response = requests.get(url, stream=True) - if response.status_code == 200: - total_size = int(response.headers.get("content-length", 0)) - progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) - - zip_file = io.BytesIO() - for chunk in response.iter_content(chunk_size=1024): - if chunk: - zip_file.write(chunk) - progress_bar.update(len(chunk)) - - progress_bar.close() - - zip_file.seek(0) - - with zipfile.ZipFile(zip_file, "r") as zip_ref: - zip_ref.extractall(destination_folder) - return True - else: - return False - - -def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id): - # push to main to indicate latest version - hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) - - # push to version branch - hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision) - - # create and store meta_data - meta_data_dir = root / dataset_id / "meta_data" - meta_data_dir.mkdir(parents=True, exist_ok=True) - - api = HfApi() - - # info - info_path = meta_data_dir / "info.json" - with open(str(info_path), "w") as f: - json.dump(info, f, indent=4) - api.upload_file( - path_or_fileobj=info_path, - path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""), - repo_id=f"lerobot/{dataset_id}", - repo_type="dataset", - ) - api.upload_file( - path_or_fileobj=info_path, - path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""), - repo_id=f"lerobot/{dataset_id}", - repo_type="dataset", - revision=revision, - ) - - # stats - stats_path = meta_data_dir / "stats.safetensors" - 
save_file(flatten_dict(stats), stats_path) - api.upload_file( - path_or_fileobj=stats_path, - path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""), - repo_id=f"lerobot/{dataset_id}", - repo_type="dataset", - ) - api.upload_file( - path_or_fileobj=stats_path, - path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""), - repo_id=f"lerobot/{dataset_id}", - repo_type="dataset", - revision=revision, - ) - - # episode_data_index - episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index} - ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors" - save_file(episode_data_index, ep_data_idx_path) - api.upload_file( - path_or_fileobj=ep_data_idx_path, - path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""), - repo_id=f"lerobot/{dataset_id}", - repo_type="dataset", - ) - api.upload_file( - path_or_fileobj=ep_data_idx_path, - path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""), - repo_id=f"lerobot/{dataset_id}", - repo_type="dataset", - revision=revision, - ) - - # copy in tests folder, the first episode and the meta_data directory - num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0] - hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk( - f"tests/data/lerobot/{dataset_id}/train" - ) - if Path(f"tests/data/lerobot/{dataset_id}/meta_data").exists(): - shutil.rmtree(f"tests/data/lerobot/{dataset_id}/meta_data") - shutil.copytree(meta_data_dir, f"tests/data/lerobot/{dataset_id}/meta_data") - - -def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10): - try: - import pymunk - from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely - - from lerobot.common.datasets._diffusion_policy_replay_buffer import ( - ReplayBuffer as DiffusionPolicyReplayBuffer, - ) - except ModuleNotFoundError as e: - print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`") - raise e - - # as define in env - success_threshold = 0.95 # 95% coverage, - - pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip" - pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr") - - root = Path(root) - raw_dir = root / f"{dataset_id}_raw" - zarr_path = (raw_dir / pusht_zarr).resolve() - if not zarr_path.is_dir(): - raw_dir.mkdir(parents=True, exist_ok=True) - download_and_extract_zip(pusht_url, raw_dir) - - # load - dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) - - episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs()) - num_episodes = dataset_dict.meta["episode_ends"].shape[0] - assert len( - {dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118 - ), "Some data type dont have the same number of total frames." 
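# Note on the zarr layout used here and by the new processors further down (values are made up
# for illustration): `meta/episode_ends` stores the cumulative end index of each episode, so
# episode i spans frames [episode_ends[i-1], episode_ends[i]). The loop below slices frames
# exactly this way:
#
#     episode_ends = [300, 650, 900]            # 3 episodes, 900 frames in total (illustrative)
#     id_from = 0
#     for episode_id, id_to in enumerate(episode_ends):
#         frames = slice(id_from, id_to)        # frames belonging to this episode
#         id_from = id_to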
- - # TODO: verify that goal pose is expected to be fixed - goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) - goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle) - - imgs = torch.from_numpy(dataset_dict["img"]) # b h w c - states = torch.from_numpy(dataset_dict["state"]) - actions = torch.from_numpy(dataset_dict["action"]) - - ep_dicts = [] - episode_data_index = {"from": [], "to": []} - - id_from = 0 - for episode_id in tqdm.tqdm(range(num_episodes)): - id_to = dataset_dict.meta["episode_ends"][episode_id] - - num_frames = id_to - id_from - - assert (episode_ids[id_from:id_to] == episode_id).all() - - image = imgs[id_from:id_to] - assert image.min() >= 0.0 - assert image.max() <= 255.0 - image = image.type(torch.uint8) - - state = states[id_from:id_to] - agent_pos = state[:, :2] - block_pos = state[:, 2:4] - block_angle = state[:, 4] - - reward = torch.zeros(num_frames) - success = torch.zeros(num_frames, dtype=torch.bool) - done = torch.zeros(num_frames, dtype=torch.bool) - for i in range(num_frames): - space = pymunk.Space() - space.gravity = 0, 0 - space.damping = 0 - - # Add walls. - walls = [ - PushTEnv.add_segment(space, (5, 506), (5, 5), 2), - PushTEnv.add_segment(space, (5, 5), (506, 5), 2), - PushTEnv.add_segment(space, (506, 5), (506, 506), 2), - PushTEnv.add_segment(space, (5, 506), (506, 506), 2), - ] - space.add(*walls) - - block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item()) - goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) - block_geom = pymunk_to_shapely(block_body, block_body.shapes) - intersection_area = goal_geom.intersection(block_geom).area - goal_area = goal_geom.area - coverage = intersection_area / goal_area - reward[i] = np.clip(coverage / success_threshold, 0, 1) - success[i] = coverage > success_threshold - - # last step of demonstration is considered done - done[-1] = True - - ep_dict = { - "observation.image": [PILImage.fromarray(x.numpy()) for x in image], - "observation.state": agent_pos, - "action": actions[id_from:id_to], - "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_index": torch.arange(0, num_frames, 1), - "timestamp": torch.arange(0, num_frames, 1) / fps, - # "next.observation.image": image[1:], - # "next.observation.state": agent_pos[1:], - # TODO(rcadene): verify that reward and done are aligned with image and agent_pos - "next.reward": torch.cat([reward[1:], reward[[-1]]]), - "next.done": torch.cat([done[1:], done[[-1]]]), - "next.success": torch.cat([success[1:], success[[-1]]]), - } - ep_dicts.append(ep_dict) - - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - - id_from += num_frames - - data_dict = concatenate_episodes(ep_dicts) - - features = { - "observation.image": Image(), - "observation.state": Sequence( - length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) - ), - "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_index": Value(dtype="int64", id=None), - "frame_index": Value(dtype="int64", id=None), - "timestamp": Value(dtype="float32", id=None), - "next.reward": Value(dtype="float32", id=None), - "next.done": Value(dtype="bool", id=None), - "next.success": Value(dtype="bool", id=None), - "index": Value(dtype="int64", id=None), - } - features = Features(features) - hf_dataset = Dataset.from_dict(data_dict, features=features) - hf_dataset.set_transform(hf_transform_to_torch) - - info = { 
- "fps": fps, - } - stats = compute_stats(hf_dataset) - push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id) - - -def download_and_upload_xarm(root, revision, dataset_id, fps=15): - root = Path(root) - raw_dir = root / "xarm_datasets_raw" - if not raw_dir.exists(): - import zipfile - - import gdown - - raw_dir.mkdir(parents=True, exist_ok=True) - # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py - url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" - zip_path = raw_dir / "data.zip" - gdown.download(url, str(zip_path), quiet=False) - print("Extracting...") - with zipfile.ZipFile(str(zip_path), "r") as zip_f: - for member in zip_f.namelist(): - if member.startswith("data/xarm") and member.endswith(".pkl"): - print(member) - zip_f.extract(member=member) - zip_path.unlink() - - dataset_path = root / f"{dataset_id}" / "buffer.pkl" - print(f"Using offline dataset '{dataset_path}'") - with open(dataset_path, "rb") as f: - dataset_dict = pickle.load(f) - - ep_dicts = [] - episode_data_index = {"from": [], "to": []} - - id_from = 0 - id_to = 0 - episode_id = 0 - total_frames = dataset_dict["actions"].shape[0] - for i in tqdm.tqdm(range(total_frames)): - id_to += 1 - - if not dataset_dict["dones"][i]: - continue - - num_frames = id_to - id_from - - image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to]) - image = einops.rearrange(image, "b c h w -> b h w c") - state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to]) - action = torch.tensor(dataset_dict["actions"][id_from:id_to]) - # TODO(rcadene): we have a missing last frame which is the observation when the env is done - # it is critical to have this frame for tdmpc to predict a "done observation/state" - # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to]) - # next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to]) - next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to]) - next_done = torch.tensor(dataset_dict["dones"][id_from:id_to]) - - ep_dict = { - "observation.image": [PILImage.fromarray(x.numpy()) for x in image], - "observation.state": state, - "action": action, - "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_index": torch.arange(0, num_frames, 1), - "timestamp": torch.arange(0, num_frames, 1) / fps, - # "next.observation.image": next_image, - # "next.observation.state": next_state, - "next.reward": next_reward, - "next.done": next_done, - } - ep_dicts.append(ep_dict) - - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - - id_from = id_to - episode_id += 1 - - data_dict = concatenate_episodes(ep_dicts) - - features = { - "observation.image": Image(), - "observation.state": Sequence( - length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) - ), - "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_index": Value(dtype="int64", id=None), - "frame_index": Value(dtype="int64", id=None), - "timestamp": Value(dtype="float32", id=None), - "next.reward": Value(dtype="float32", id=None), - "next.done": Value(dtype="bool", id=None), - #'next.success': Value(dtype='bool', id=None), - "index": Value(dtype="int64", id=None), - } - features = Features(features) - hf_dataset = Dataset.from_dict(data_dict, features=features) - hf_dataset.set_transform(hf_transform_to_torch) - - info = { - "fps": 
fps, - } - stats = compute_stats(hf_dataset) - push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id) - - -def download_and_upload_aloha(root, revision, dataset_id, fps=50): - folder_urls = { - "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF", - "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N", - "aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo", - "aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj", - } - - ep48_urls = { - "aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link", - "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link", - "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link", - "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link", - } - - ep49_urls = { - "aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link", - "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link", - "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link", - "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link", - } - - num_episodes = { - "aloha_sim_insertion_human": 50, - "aloha_sim_insertion_scripted": 50, - "aloha_sim_transfer_cube_human": 50, - "aloha_sim_transfer_cube_scripted": 50, - } - - episode_len = { - "aloha_sim_insertion_human": 500, - "aloha_sim_insertion_scripted": 400, - "aloha_sim_transfer_cube_human": 400, - "aloha_sim_transfer_cube_scripted": 400, - } - - cameras = { - "aloha_sim_insertion_human": ["top"], - "aloha_sim_insertion_scripted": ["top"], - "aloha_sim_transfer_cube_human": ["top"], - "aloha_sim_transfer_cube_scripted": ["top"], - } - - root = Path(root) - raw_dir = root / f"{dataset_id}_raw" - if not raw_dir.is_dir(): - import gdown - - assert dataset_id in folder_urls - assert dataset_id in ep48_urls - assert dataset_id in ep49_urls - - raw_dir.mkdir(parents=True, exist_ok=True) - - gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir)) - - # because of the 50 files limit per directory, two files episode 48 and 49 were missing - gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True) - gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True) - - ep_dicts = [] - episode_data_index = {"from": [], "to": []} - - id_from = 0 - for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])): - ep_path = raw_dir / f"episode_{ep_id}.hdf5" - with h5py.File(ep_path, "r") as ep: - num_frames = ep["/action"].shape[0] - assert episode_len[dataset_id] == num_frames - - # last step of demonstration is considered done - done = torch.zeros(num_frames, dtype=torch.bool) - done[-1] = True - - state = torch.from_numpy(ep["/observations/qpos"][:]) - action = torch.from_numpy(ep["/action"][:]) - - ep_dict = {} - - for cam in cameras[dataset_id]: - image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c - # image = einops.rearrange(image, 
"b h w c -> b c h w").contiguous() - ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image] - # ep_dict[f"next.observation.images.{cam}"] = image - - ep_dict.update( - { - "observation.state": state, - "action": action, - "episode_index": torch.tensor([ep_id] * num_frames), - "frame_index": torch.arange(0, num_frames, 1), - "timestamp": torch.arange(0, num_frames, 1) / fps, - # "next.observation.state": state, - # TODO(rcadene): compute reward and success - # "next.reward": reward, - "next.done": done, - # "next.success": success, - } - ) - - assert isinstance(ep_id, int) - ep_dicts.append(ep_dict) - - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - - id_from += num_frames - - data_dict = concatenate_episodes(ep_dicts) - - features = { - "observation.images.top": Image(), - "observation.state": Sequence( - length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) - ), - "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_index": Value(dtype="int64", id=None), - "frame_index": Value(dtype="int64", id=None), - "timestamp": Value(dtype="float32", id=None), - # "next.reward": Value(dtype="float32", id=None), - "next.done": Value(dtype="bool", id=None), - # "next.success": Value(dtype="bool", id=None), - "index": Value(dtype="int64", id=None), - } - features = Features(features) - hf_dataset = Dataset.from_dict(data_dict, features=features) - hf_dataset.set_transform(hf_transform_to_torch) - - info = { - "fps": fps, - } - stats = compute_stats(hf_dataset) - push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id) - - -def download_and_upload_umi(root, revision, dataset_id, fps=10): - # fps is equal to 10 source:https://arxiv.org/pdf/2402.10329.pdf#table.caption.16 - import os - import re - import shutil - from glob import glob - - import numpy as np - import torch - import tqdm - import zarr - from datasets import Dataset, Features, Image, Sequence, Value - - from lerobot.common.datasets._umi_imagecodecs_numcodecs import register_codecs - - # NOTE: This is critical otherwise ValueError: codec not available: 'imagecodecs_jpegxl' - # will be raised - register_codecs() - - url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip" - cup_in_the_wild_zarr = Path("umi/cup_in_the_wild/cup_in_the_wild.zarr") - - root = Path(root) - raw_dir = root / f"{dataset_id}_raw" - zarr_path = (raw_dir / cup_in_the_wild_zarr).resolve() - if not zarr_path.is_dir(): - raw_dir.mkdir(parents=True, exist_ok=True) - download_and_extract_zip(url_cup_in_the_wild, zarr_path) - zarr_data = zarr.open(zarr_path, mode="r") - - # We process the image data separately because it is too large to fit in memory - end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:]) - start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:]) - eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:]) - eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:]) - gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:]) - - states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1) - states = torch.cat([states_pos, gripper_width], dim=1) - - def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray: - # Optimized and simplified version of this function: 
https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374 - from numba import jit - - @jit(nopython=True) - def _get_episode_idxs(episode_ends): - result = np.zeros((episode_ends[-1],), dtype=np.int64) - start_idx = 0 - for episode_number, end_idx in enumerate(episode_ends): - result[start_idx:end_idx] = episode_number - start_idx = end_idx - return result - - return _get_episode_idxs(episode_ends) - - episode_ends = zarr_data["meta/episode_ends"][:] - num_episodes: int = episode_ends.shape[0] - - episode_ids = torch.from_numpy(get_episode_idxs(episode_ends)) - - # We convert it in torch tensor later because the jit function does not support torch tensors - episode_ends = torch.from_numpy(episode_ends) - - ep_dicts = [] - episode_data_index = {"from": [], "to": []} - id_from = 0 - - for episode_id in tqdm.tqdm(range(num_episodes)): - id_to = episode_ends[episode_id] - - num_frames = id_to - id_from - - assert ( - episode_ids[id_from:id_to] == episode_id - ).all(), f"episode_ids[{id_from}:{id_to}] != {episode_id}" - - state = states[id_from:id_to] - ep_dict = { - # observation.image will be filled later - "observation.state": state, - "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_index": torch.arange(0, num_frames, 1), - "timestamp": torch.arange(0, num_frames, 1) / fps, - "episode_data_index_from": torch.tensor([id_from] * num_frames), - "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), - "end_pose": end_pose[id_from:id_to], - "start_pos": start_pos[id_from:id_to], - "gripper_width": gripper_width[id_from:id_to], - } - ep_dicts.append(ep_dict) - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - id_from += num_frames - - data_dict = concatenate_episodes(ep_dicts) - - total_frames = id_from - data_dict["index"] = torch.arange(0, total_frames, 1) - - print("Saving images to disk in temporary folder...") - # datasets.Image() can take a list of paths to images, so we save the images to a temporary folder - # to avoid loading them all in memory - _umi_save_images_concurrently(zarr_data, "tmp_umi_images", max_workers=12) - print("Saving images to disk in temporary folder... Done") - - # Sort files by number eg. 1.png, 2.png, 3.png, 9.png, 10.png instead of 1.png, 10.png, 2.png, 3.png, 9.png - # to correctly match the images with the data - images_path = sorted(glob("tmp_umi_images/*"), key=lambda x: int(re.search(r"(\d+)\.png$", x).group(1))) - data_dict["observation.image"] = images_path - - features = { - "observation.image": Image(), - "observation.state": Sequence( - length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) - ), - "episode_index": Value(dtype="int64", id=None), - "frame_index": Value(dtype="int64", id=None), - "timestamp": Value(dtype="float32", id=None), - "index": Value(dtype="int64", id=None), - "episode_data_index_from": Value(dtype="int64", id=None), - "episode_data_index_to": Value(dtype="int64", id=None), - # `start_pos` and `end_pos` respectively represent the positions of the end-effector - # at the beginning and the end of the episode. - # `gripper_width` indicates the distance between the grippers, and this value is included - # in the state vector, which comprises the concatenation of the end-effector position - # and gripper width. 
- "end_pose": Sequence(length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)), - "start_pos": Sequence( - length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None) - ), - "gripper_width": Sequence( - length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None) - ), - } - features = Features(features) - hf_dataset = Dataset.from_dict(data_dict, features=features) - hf_dataset.set_transform(hf_transform_to_torch) - - info = { - "fps": fps, - } - stats = compute_stats(hf_dataset) - push_to_hub( - hf_dataset=hf_dataset, - episode_data_index=episode_data_index, - info=info, - stats=stats, - root=root, - revision=revision, - dataset_id=dataset_id, - ) - # Cleanup - if os.path.exists("tmp_umi_images"): - print("Removing temporary images folder") - shutil.rmtree("tmp_umi_images") - print("Cleanup done") - - -def _umi_clear_folder(folder_path: str): - import os - - """ - Clears all the content of the specified folder. Creates the folder if it does not exist. - - Args: - folder_path (str): Path to the folder to clear. - - Examples: - >>> import os - >>> os.makedirs('example_folder', exist_ok=True) - >>> with open('example_folder/temp_file.txt', 'w') as f: - ... f.write('example') - >>> clear_folder('example_folder') - >>> os.listdir('example_folder') - [] - """ - if os.path.exists(folder_path): - for filename in os.listdir(folder_path): - file_path = os.path.join(folder_path, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") - else: - os.makedirs(folder_path) - - -def _umi_save_image(img_array: np.array, i: int, folder_path: str): - import os - - """ - Saves a single image to the specified folder. - - Args: - img_array (ndarray): The numpy array of the image. - i (int): Index of the image, used for naming. - folder_path (str): Path to the folder where the image will be saved. - """ - img = PILImage.fromarray(img_array) - img_format = "PNG" if img_array.dtype == np.uint8 else "JPEG" - img.save(os.path.join(folder_path, f"{i}.{img_format.lower()}"), quality=100) - - -def _umi_save_images_concurrently(zarr_data: dict, folder_path: str, max_workers: int = 4): - from concurrent.futures import ThreadPoolExecutor - - """ - Saves images from the zarr_data to the specified folder using multithreading. - - Args: - zarr_data (dict): A dictionary containing image data in an array format. - folder_path (str): Path to the folder where images will be saved. - max_workers (int): The maximum number of threads to use for saving images. 
- """ - num_images = len(zarr_data["data/camera0_rgb"]) - _umi_clear_folder(folder_path) # Clear or create folder first - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - [ - executor.submit(_umi_save_image, zarr_data["data/camera0_rgb"][i], i, folder_path) - for i in range(num_images) - ] - - -if __name__ == "__main__": - root = "data" - revision = "v1.1" - dataset_ids = [ - "pusht", - "xarm_lift_medium", - "xarm_lift_medium_replay", - "xarm_push_medium", - "xarm_push_medium_replay", - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", - "umi_cup_in_the_wild", - ] - for dataset_id in dataset_ids: - download_and_upload(root, revision, dataset_id) diff --git a/lerobot/__init__.py b/lerobot/__init__.py index eb317e8c..e11fba8b 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -61,6 +61,13 @@ available_datasets = list( itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env) ) +# TODO(rcadene, aliberts, alexander-soare): Add real-world env with a gym API +available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"] + +available_datasets = list( + itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env) +) + available_policies = [ "act", "diffusion", diff --git a/lerobot/common/datasets/_diffusion_policy_replay_buffer.py b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py similarity index 100% rename from lerobot/common/datasets/_diffusion_policy_replay_buffer.py rename to lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py new file mode 100644 index 00000000..263c929d --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py @@ -0,0 +1,179 @@ +""" +This file contains all obsolete download scripts. They are centralized here to not have to load +useless dependencies when using datasets. 
+""" + +import io +from pathlib import Path + +import tqdm + + +def download_raw(root, dataset_id) -> Path: + if "pusht" in dataset_id: + return download_pusht(root=root, dataset_id=dataset_id) + elif "xarm" in dataset_id: + return download_xarm(root=root, dataset_id=dataset_id) + elif "aloha" in dataset_id: + return download_aloha(root=root, dataset_id=dataset_id) + elif "umi" in dataset_id: + return download_umi(root=root, dataset_id=dataset_id) + else: + raise ValueError(dataset_id) + + +def download_and_extract_zip(url: str, destination_folder: Path) -> bool: + import zipfile + + import requests + + print(f"downloading from {url}") + response = requests.get(url, stream=True) + if response.status_code == 200: + total_size = int(response.headers.get("content-length", 0)) + progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) + + zip_file = io.BytesIO() + for chunk in response.iter_content(chunk_size=1024): + if chunk: + zip_file.write(chunk) + progress_bar.update(len(chunk)) + + progress_bar.close() + + zip_file.seek(0) + + with zipfile.ZipFile(zip_file, "r") as zip_ref: + zip_ref.extractall(destination_folder) + return True + else: + return False + + +def download_pusht(root: str, dataset_id: str = "pusht", fps: int = 10) -> Path: + pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip" + pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr") + + root = Path(root) + raw_dir: Path = root / f"{dataset_id}_raw" + zarr_path: Path = (raw_dir / pusht_zarr).resolve() + if not zarr_path.is_dir(): + raw_dir.mkdir(parents=True, exist_ok=True) + download_and_extract_zip(pusht_url, raw_dir) + return zarr_path + + +def download_xarm(root: str, dataset_id: str, fps: int = 15) -> Path: + root = Path(root) + raw_dir: Path = root / "xarm_datasets_raw" + if not raw_dir.exists(): + import zipfile + + import gdown + + raw_dir.mkdir(parents=True, exist_ok=True) + # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py + url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" + zip_path = raw_dir / "data.zip" + gdown.download(url, str(zip_path), quiet=False) + print("Extracting...") + with zipfile.ZipFile(str(zip_path), "r") as zip_f: + for member in zip_f.namelist(): + if member.startswith("data/xarm") and member.endswith(".pkl"): + print(member) + zip_f.extract(member=member) + zip_path.unlink() + + dataset_path: Path = root / f"{dataset_id}" + return dataset_path + + +def download_aloha(root: str, dataset_id: str) -> Path: + folder_urls = { + "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF", + "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N", + "aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo", + "aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj", + } + + ep48_urls = { + "aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link", + "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link", + "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link", + "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link", + } + + ep49_urls = { + 
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link", + "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link", + "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link", + "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link", + } + num_episodes = { # noqa: F841 # we keep this for reference + "aloha_sim_insertion_human": 50, + "aloha_sim_insertion_scripted": 50, + "aloha_sim_transfer_cube_human": 50, + "aloha_sim_transfer_cube_scripted": 50, + } + + episode_len = { # noqa: F841 # we keep this for reference + "aloha_sim_insertion_human": 500, + "aloha_sim_insertion_scripted": 400, + "aloha_sim_transfer_cube_human": 400, + "aloha_sim_transfer_cube_scripted": 400, + } + + cameras = { # noqa: F841 # we keep this for reference + "aloha_sim_insertion_human": ["top"], + "aloha_sim_insertion_scripted": ["top"], + "aloha_sim_transfer_cube_human": ["top"], + "aloha_sim_transfer_cube_scripted": ["top"], + } + root = Path(root) + raw_dir: Path = root / f"{dataset_id}_raw" + if not raw_dir.is_dir(): + import gdown + + assert dataset_id in folder_urls + assert dataset_id in ep48_urls + assert dataset_id in ep49_urls + + raw_dir.mkdir(parents=True, exist_ok=True) + + gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir)) + + # because of the 50 files limit per directory, two files episode 48 and 49 were missing + gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True) + gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True) + return raw_dir + + +def download_umi(root: str, dataset_id: str) -> Path: + url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip" + cup_in_the_wild_zarr = Path("umi/cup_in_the_wild/cup_in_the_wild.zarr") + + root = Path(root) + raw_dir: Path = root / f"{dataset_id}_raw" + zarr_path: Path = (raw_dir / cup_in_the_wild_zarr).resolve() + if not zarr_path.is_dir(): + raw_dir.mkdir(parents=True, exist_ok=True) + download_and_extract_zip(url_cup_in_the_wild, zarr_path) + return zarr_path + + +if __name__ == "__main__": + root = "data" + dataset_ids = [ + "pusht", + "xarm_lift_medium", + "xarm_lift_medium_replay", + "xarm_push_medium", + "xarm_push_medium_replay", + "aloha_sim_insertion_human", + "aloha_sim_insertion_scripted", + "aloha_sim_transfer_cube_human", + "aloha_sim_transfer_cube_scripted", + "umi_cup_in_the_wild", + ] + for dataset_id in dataset_ids: + download_raw(root=root, dataset_id=dataset_id) diff --git a/lerobot/common/datasets/_umi_imagecodecs_numcodecs.py b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py similarity index 100% rename from lerobot/common/datasets/_umi_imagecodecs_numcodecs.py rename to lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_processor.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_processor.py new file mode 100644 index 00000000..f6a66577 --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_processor.py @@ -0,0 +1,199 @@ +import re +from pathlib import Path + +import h5py +import torch +import tqdm +from datasets import Dataset, Features, Image, Sequence, Value +from PIL import Image as PILImage + +from 
lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes +from lerobot.common.datasets.utils import ( + hf_transform_to_torch, +) + + +class AlohaProcessor: + """ + Process HDF5 files formatted like in: https://github.com/tonyzhaozh/act + + Attributes: + folder_path (Path): Path to the directory containing HDF5 files. + cameras (list[str]): List of camera identifiers to check in the files. + fps (int): Frames per second used in timestamp calculations. + + Methods: + is_valid() -> bool: + Validates if each HDF5 file within the folder contains all required datasets. + preprocess() -> dict: + Processes the files and returns structured data suitable for further analysis. + to_hf_dataset(data_dict: dict) -> Dataset: + Converts processed data into a Hugging Face Dataset object. + """ + + def __init__(self, folder_path: Path, cameras: list[str] | None = None, fps: int | None = None): + """ + Initializes the AlohaProcessor with a specified directory path containing HDF5 files, + an optional list of cameras, and a frame rate. + + Args: + folder_path (Path): The directory path where HDF5 files are stored. + cameras (list[str] | None): Optional list of cameras to validate within the files. Defaults to ['top'] if None. + fps (int): Frame rate for the datasets, used in time calculations. Default is 50. + + Examples: + >>> processor = AlohaProcessor(Path("path_to_hdf5_directory"), ["camera1", "camera2"]) + >>> processor.is_valid() + True + """ + self.folder_path = folder_path + if cameras is None: + cameras = ["top"] + self.cameras = cameras + if fps is None: + fps = 50 + self._fps = fps + + @property + def fps(self) -> int: + return self._fps + + def is_valid(self) -> bool: + """ + Validates the HDF5 files in the specified folder to ensure they contain the required datasets + for actions, positions, and images for each specified camera. + + Returns: + bool: True if all files are valid HDF5 files with all required datasets, False otherwise. + """ + hdf5_files: list[Path] = list(self.folder_path.glob("episode_*.hdf5")) + if len(hdf5_files) == 0: + return False + try: + hdf5_files = sorted( + hdf5_files, key=lambda x: int(re.search(r"episode_(\d+).hdf5", x.name).group(1)) + ) + except AttributeError: + # All file names must contain a numerical identifier matching 'episode_(\\d+).hdf5 + return False + + # Check if the sequence is consecutive eg episode_0, episode_1, episode_2, etc. + # If not, return False + previous_number = None + for file in hdf5_files: + current_number = int(re.search(r"episode_(\d+).hdf5", file.name).group(1)) + if previous_number is not None and current_number - previous_number != 1: + return False + previous_number = current_number + + for file in hdf5_files: + try: + with h5py.File(file, "r") as file: + # Check for the expected datasets within the HDF5 file + required_datasets = ["/action", "/observations/qpos"] + # Add camera-specific image datasets to the required datasets + camera_datasets = [f"/observations/images/{cam}" for cam in self.cameras] + required_datasets.extend(camera_datasets) + + if not all(dataset in file for dataset in required_datasets): + return False + except OSError: + return False + return True + + def preprocess(self): + """ + Collects episode data from the HDF5 file and returns it as an AlohaStep named tuple. + + Returns: + AlohaStep: Named tuple containing episode data. + + Raises: + ValueError: If the file is not valid. 
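        Example (illustrative; the folder path is an assumption based on download_aloha's layout):
            >>> processor = AlohaProcessor(Path("data/aloha_sim_insertion_human_raw"))
            >>> processor.is_valid()
            True
            >>> data_dict, episode_data_index = processor.preprocess()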
+ """ + if not self.is_valid(): + raise ValueError("The HDF5 file is invalid or does not contain the required datasets.") + + hdf5_files = list(self.folder_path.glob("*.hdf5")) + hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+)", x.name).group(1))) + ep_dicts = [] + episode_data_index = {"from": [], "to": []} + + id_from = 0 + + for ep_path in tqdm.tqdm(hdf5_files): + with h5py.File(ep_path, "r") as ep: + ep_id = int(re.search(r"episode_(\d+)", ep_path.name).group(1)) + num_frames = ep["/action"].shape[0] + + # last step of demonstration is considered done + done = torch.zeros(num_frames, dtype=torch.bool) + done[-1] = True + + state = torch.from_numpy(ep["/observations/qpos"][:]) + action = torch.from_numpy(ep["/action"][:]) + + ep_dict = {} + + for cam in self.cameras: + image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c + ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image] + + ep_dict.update( + { + "observation.state": state, + "action": action, + "episode_index": torch.tensor([ep_id] * num_frames), + "frame_index": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # TODO(rcadene): compute reward and success + # "next.reward": reward, + "next.done": done, + # "next.success": success, + } + ) + + assert isinstance(ep_id, int) + ep_dicts.append(ep_dict) + + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames) + + id_from += num_frames + + data_dict = concatenate_episodes(ep_dicts) + return data_dict, episode_data_index + + def to_hf_dataset(self, data_dict) -> Dataset: + """ + Converts a dictionary of data into a Hugging Face Dataset object. + + Args: + data_dict (dict): A dictionary containing the data to be converted. + + Returns: + Dataset: The converted Hugging Face Dataset object. 
+ """ + image_features = {f"observation.images.{cam}": Image() for cam in self.cameras} + features = { + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + # "next.reward": Value(dtype="float32", id=None), + "next.done": Value(dtype="bool", id=None), + # "next.success": Value(dtype="bool", id=None), + "index": Value(dtype="int64", id=None), + } + update_features = {**image_features, **features} + features = Features(update_features) + hf_dataset = Dataset.from_dict(data_dict, features=features) + hf_dataset.set_transform(hf_transform_to_torch) + + return hf_dataset + + def cleanup(self): + pass diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_processor.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_processor.py new file mode 100644 index 00000000..2f0ec3d6 --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_processor.py @@ -0,0 +1,180 @@ +from pathlib import Path + +import numpy as np +import torch +import tqdm +import zarr +from datasets import Dataset, Features, Image, Sequence, Value +from PIL import Image as PILImage + +from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes +from lerobot.common.datasets.utils import ( + hf_transform_to_torch, +) + + +class PushTProcessor: + """ Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy + """ + def __init__(self, folder_path: Path, fps: int | None = None): + self.zarr_path = folder_path + if fps is None: + fps = 10 + self._fps = fps + + @property + def fps(self) -> int: + return self._fps + + def is_valid(self): + try: + zarr_data = zarr.open(self.zarr_path, mode="r") + except Exception: + # TODO (azouitine): Handle the exception properly + return False + required_datasets = { + "data/action", + "data/img", + "data/keypoint", + "data/n_contacts", + "data/state", + "meta/episode_ends", + } + for dataset in required_datasets: + if dataset not in zarr_data: + return False + nb_frames = zarr_data["data/img"].shape[0] + + required_datasets.remove("meta/episode_ends") + + return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets) + + def preprocess(self): + try: + import pymunk + from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely + + from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import ( + ReplayBuffer as DiffusionPolicyReplayBuffer, + ) + except ModuleNotFoundError as e: + print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`") + raise e + + # as define in env + success_threshold = 0.95 # 95% coverage, + + dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path( + self.zarr_path + ) # , keys=['img', 'state', 'action']) + + episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs()) + num_episodes = dataset_dict.meta["episode_ends"].shape[0] + assert len( + {dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118 + ), "Some data type dont have the same number of total frames." 
+ + # TODO: verify that goal pose is expected to be fixed + goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) + goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle) + + imgs = torch.from_numpy(dataset_dict["img"]) # b h w c + states = torch.from_numpy(dataset_dict["state"]) + actions = torch.from_numpy(dataset_dict["action"]) + + ep_dicts = [] + episode_data_index = {"from": [], "to": []} + + id_from = 0 + for episode_id in tqdm.tqdm(range(num_episodes)): + id_to = dataset_dict.meta["episode_ends"][episode_id] + + num_frames = id_to - id_from + + assert (episode_ids[id_from:id_to] == episode_id).all() + + image = imgs[id_from:id_to] + assert image.min() >= 0.0 + assert image.max() <= 255.0 + image = image.type(torch.uint8) + + state = states[id_from:id_to] + agent_pos = state[:, :2] + block_pos = state[:, 2:4] + block_angle = state[:, 4] + + reward = torch.zeros(num_frames) + success = torch.zeros(num_frames, dtype=torch.bool) + done = torch.zeros(num_frames, dtype=torch.bool) + for i in range(num_frames): + space = pymunk.Space() + space.gravity = 0, 0 + space.damping = 0 + + # Add walls. + walls = [ + PushTEnv.add_segment(space, (5, 506), (5, 5), 2), + PushTEnv.add_segment(space, (5, 5), (506, 5), 2), + PushTEnv.add_segment(space, (506, 5), (506, 506), 2), + PushTEnv.add_segment(space, (5, 506), (506, 506), 2), + ] + space.add(*walls) + + block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item()) + goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) + block_geom = pymunk_to_shapely(block_body, block_body.shapes) + intersection_area = goal_geom.intersection(block_geom).area + goal_area = goal_geom.area + coverage = intersection_area / goal_area + reward[i] = np.clip(coverage / success_threshold, 0, 1) + success[i] = coverage > success_threshold + + # last step of demonstration is considered done + done[-1] = True + + ep_dict = { + "observation.image": [PILImage.fromarray(x.numpy()) for x in image], + "observation.state": agent_pos, + "action": actions[id_from:id_to], + "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_index": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.image": image[1:], + # "next.observation.state": agent_pos[1:], + # TODO(rcadene): verify that reward and done are aligned with image and agent_pos + "next.reward": torch.cat([reward[1:], reward[[-1]]]), + "next.done": torch.cat([done[1:], done[[-1]]]), + "next.success": torch.cat([success[1:], success[[-1]]]), + } + ep_dicts.append(ep_dict) + + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames) + + id_from += num_frames + + data_dict = concatenate_episodes(ep_dicts) + return data_dict, episode_data_index + + def to_hf_dataset(self, data_dict): + features = { + "observation.image": Image(), + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + "next.reward": Value(dtype="float32", id=None), + "next.done": Value(dtype="bool", id=None), + "next.success": Value(dtype="bool", id=None), + "index": Value(dtype="int64", id=None), + } + features = Features(features) + hf_dataset = Dataset.from_dict(data_dict, 
features=features) + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + def cleanup(self): + pass diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_processor.py b/lerobot/common/datasets/push_dataset_to_hub/umi_processor.py new file mode 100644 index 00000000..fb91ae08 --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/umi_processor.py @@ -0,0 +1,280 @@ +import os +import re +import shutil +from glob import glob + +import numpy as np +import torch +import tqdm +import zarr +from datasets import Dataset, Features, Image, Sequence, Value +from PIL import Image as PILImage + +from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs +from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes +from lerobot.common.datasets.utils import ( + hf_transform_to_torch, +) + + +class UmiProcessor: + """ + Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface + + Attributes: + folder_path (str): The path to the folder containing Zarr datasets. + fps (int): Frames per second, used to calculate timestamps for frames. + + """ + + def __init__(self, folder_path: str, fps: int | None = None): + self.zarr_path = folder_path + if fps is None: + # TODO (azouitine): Add reference to the paper + fps = 15 + self._fps = fps + register_codecs() + + @property + def fps(self) -> int: + return self._fps + + def is_valid(self) -> bool: + """ + Validates the Zarr folder to ensure it contains all required datasets with consistent frame counts. + + Returns: + bool: True if all required datasets are present and have consistent frame counts, False otherwise. + """ + # Check if the Zarr folder is valid + try: + zarr_data = zarr.open(self.zarr_path, mode="r") + except Exception: + # TODO (azouitine): Handle the exception properly + return False + required_datasets = { + "data/robot0_demo_end_pose", + "data/robot0_demo_start_pose", + "data/robot0_eef_pos", + "data/robot0_eef_rot_axis_angle", + "data/robot0_gripper_width", + "meta/episode_ends", + "data/camera0_rgb", + } + for dataset in required_datasets: + if dataset not in zarr_data: + return False + nb_frames = zarr_data["data/camera0_rgb"].shape[0] + + required_datasets.remove("meta/episode_ends") + + return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets) + + def preprocess(self): + """ + Collects and processes all episodes from the Zarr dataset into structured data dictionaries. + + Returns: + Tuple[Dict, Dict]: A tuple containing the structured episode data and episode index mappings. 
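        Example (illustrative; the zarr path is an assumption matching download_umi's layout):
            >>> processor = UmiProcessor("data/umi_cup_in_the_wild_raw/umi/cup_in_the_wild/cup_in_the_wild.zarr")
            >>> data_dict, episode_data_index = processor.preprocess()
            >>> hf_dataset = processor.to_hf_dataset(data_dict)
            >>> processor.cleanup()  # removes the temporary tmp_umi_images/ folder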
+ """ + zarr_data = zarr.open(self.zarr_path, mode="r") + + # We process the image data separately because it is too large to fit in memory + end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:]) + start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:]) + eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:]) + eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:]) + gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:]) + + states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1) + states = torch.cat([states_pos, gripper_width], dim=1) + + episode_ends = zarr_data["meta/episode_ends"][:] + num_episodes: int = episode_ends.shape[0] + + episode_ids = torch.from_numpy(self.get_episode_idxs(episode_ends)) + + # We convert it in torch tensor later because the jit function does not support torch tensors + episode_ends = torch.from_numpy(episode_ends) + + ep_dicts = [] + episode_data_index = {"from": [], "to": []} + id_from = 0 + + for episode_id in tqdm.tqdm(range(num_episodes)): + id_to = episode_ends[episode_id] + + num_frames = id_to - id_from + + assert ( + episode_ids[id_from:id_to] == episode_id + ).all(), f"episode_ids[{id_from}:{id_to}] != {episode_id}" + + state = states[id_from:id_to] + ep_dict = { + # observation.image will be filled later + "observation.state": state, + "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_index": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + "episode_data_index_from": torch.tensor([id_from] * num_frames), + "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames), + "end_pose": end_pose[id_from:id_to], + "start_pos": start_pos[id_from:id_to], + "gripper_width": gripper_width[id_from:id_to], + } + ep_dicts.append(ep_dict) + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames) + id_from += num_frames + + data_dict = concatenate_episodes(ep_dicts) + + total_frames = id_from + data_dict["index"] = torch.arange(0, total_frames, 1) + + print("Saving images to disk in temporary folder...") + # datasets.Image() can take a list of paths to images, so we save the images to a temporary folder + # to avoid loading them all in memory + _save_images_concurrently( + data=zarr_data, image_key="data/camera0_rgb", folder_path="tmp_umi_images", max_workers=12 + ) + print("Saving images to disk in temporary folder... Done") + + # Sort files by number eg. 1.png, 2.png, 3.png, 9.png, 10.png instead of 1.png, 10.png, 2.png, 3.png, 9.png + # to correctly match the images with the data + images_path = sorted( + glob("tmp_umi_images/*"), key=lambda x: int(re.search(r"(\d+)\.png$", x).group(1)) + ) + data_dict["observation.image"] = images_path + print("Images saved to disk, do not forget to delete the folder tmp_umi_images/") + + # Cleanup + return data_dict, episode_data_index + + def to_hf_dataset(self, data_dict): + """ + Converts the processed data dictionary into a Hugging Face dataset with defined features. + + Args: + data_dict (Dict): The data dictionary containing tensors and episode information. + + Returns: + Dataset: A Hugging Face dataset constructed from the provided data dictionary. 
+ """ + features = { + "observation.image": Image(), + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + "index": Value(dtype="int64", id=None), + "episode_data_index_from": Value(dtype="int64", id=None), + "episode_data_index_to": Value(dtype="int64", id=None), + # `start_pos` and `end_pos` respectively represent the positions of the end-effector + # at the beginning and the end of the episode. + # `gripper_width` indicates the distance between the grippers, and this value is included + # in the state vector, which comprises the concatenation of the end-effector position + # and gripper width. + "end_pose": Sequence( + length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None) + ), + "start_pos": Sequence( + length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None) + ), + "gripper_width": Sequence( + length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None) + ), + } + features = Features(features) + hf_dataset = Dataset.from_dict(data_dict, features=features) + hf_dataset.set_transform(hf_transform_to_torch) + + return hf_dataset + + def cleanup(self): + # Cleanup + if os.path.exists("tmp_umi_images"): + print("Removing temporary images folder") + shutil.rmtree("tmp_umi_images") + print("Cleanup done") + + @classmethod + def get_episode_idxs(cls, episode_ends: np.ndarray) -> np.ndarray: + # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374 + from numba import jit + + @jit(nopython=True) + def _get_episode_idxs(episode_ends): + result = np.zeros((episode_ends[-1],), dtype=np.int64) + start_idx = 0 + for episode_number, end_idx in enumerate(episode_ends): + result[start_idx:end_idx] = episode_number + start_idx = end_idx + return result + + return _get_episode_idxs(episode_ends) + + +def _clear_folder(folder_path: str): + """ + Clears all the content of the specified folder. Creates the folder if it does not exist. + + Args: + folder_path (str): Path to the folder to clear. + + Examples: + >>> import os + >>> os.makedirs('example_folder', exist_ok=True) + >>> with open('example_folder/temp_file.txt', 'w') as f: + ... f.write('example') + >>> clear_folder('example_folder') + >>> os.listdir('example_folder') + [] + """ + if os.path.exists(folder_path): + for filename in os.listdir(folder_path): + file_path = os.path.join(folder_path, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print(f"Failed to delete {file_path}. Reason: {e}") + else: + os.makedirs(folder_path) + + +def _save_image(img_array: np.array, i: int, folder_path: str): + """ + Saves a single image to the specified folder. + + Args: + img_array (ndarray): The numpy array of the image. + i (int): Index of the image, used for naming. + folder_path (str): Path to the folder where the image will be saved. 
+ """ + img = PILImage.fromarray(img_array) + img_format = "PNG" if img_array.dtype == np.uint8 else "JPEG" + img.save(os.path.join(folder_path, f"{i}.{img_format.lower()}"), quality=100) + + +def _save_images_concurrently(data: dict, image_key: str, folder_path: str, max_workers: int = 4): + from concurrent.futures import ThreadPoolExecutor + + """ + Saves images from the zarr_data to the specified folder using multithreading. + + Args: + zarr_data (dict): A dictionary containing image data in an array format. + folder_path (str): Path to the folder where images will be saved. + max_workers (int): The maximum number of threads to use for saving images. + """ + num_images = len(data["data/camera0_rgb"]) + _clear_folder(folder_path) # Clear or create folder first + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + [executor.submit(_save_image, data[image_key][i], i, folder_path) for i in range(num_images)] diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py new file mode 100644 index 00000000..1076eb4e --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py @@ -0,0 +1,20 @@ +import torch + + +def concatenate_episodes(ep_dicts): + data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + if torch.is_tensor(ep_dicts[0][key][0]): + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + data_dict[key].append(x) + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_processor.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_processor.py new file mode 100644 index 00000000..57401955 --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_processor.py @@ -0,0 +1,145 @@ +import pickle +from pathlib import Path + +import einops +import torch +import tqdm +from datasets import Dataset, Features, Image, Sequence, Value +from PIL import Image as PILImage + +from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes +from lerobot.common.datasets.utils import ( + hf_transform_to_torch, +) + + +class XarmProcessor: + """Process pickle files formatted like in: https://github.com/fyhMer/fowm""" + + def __init__(self, folder_path: str, fps: int | None = None): + self.folder_path = Path(folder_path) + self.keys = {"actions", "rewards", "dones", "masks"} + self.nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}} + if fps is None: + fps = 15 + self._fps = fps + + @property + def fps(self) -> int: + return self._fps + + def is_valid(self) -> bool: + # get all .pkl files + xarm_files = list(self.folder_path.glob("*.pkl")) + if len(xarm_files) != 1: + return False + + try: + with open(xarm_files[0], "rb") as f: + dataset_dict = pickle.load(f) + except Exception: + return False + + if not isinstance(dataset_dict, dict): + return False + + if not all(k in dataset_dict for k in self.keys): + return False + + # Check for consistent lengths in nested keys + try: + expected_len = len(dataset_dict["actions"]) + if any(len(dataset_dict[key]) != expected_len for key in self.keys if key in dataset_dict): + return False + + for key, subkeys in self.nested_keys.items(): + nested_dict = dataset_dict.get(key, {}) + if any( + len(nested_dict[subkey]) != expected_len for subkey in subkeys if 
subkey in nested_dict + ): + return False + except KeyError: # If any expected key or subkey is missing + return False + + return True # All checks passed + + def preprocess(self): + if not self.is_valid(): + raise ValueError("The Xarm file is invalid or does not contain the required datasets.") + + xarm_files = list(self.folder_path.glob("*.pkl")) + + with open(xarm_files[0], "rb") as f: + dataset_dict = pickle.load(f) + ep_dicts = [] + episode_data_index = {"from": [], "to": []} + + id_from = 0 + id_to = 0 + episode_id = 0 + total_frames = dataset_dict["actions"].shape[0] + for i in tqdm.tqdm(range(total_frames)): + id_to += 1 + + if not dataset_dict["dones"][i]: + continue + + num_frames = id_to - id_from + + image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to]) + image = einops.rearrange(image, "b c h w -> b h w c") + state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to]) + action = torch.tensor(dataset_dict["actions"][id_from:id_to]) + # TODO(rcadene): we have a missing last frame which is the observation when the env is done + # it is critical to have this frame for tdmpc to predict a "done observation/state" + # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to]) + # next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to]) + next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to]) + next_done = torch.tensor(dataset_dict["dones"][id_from:id_to]) + + ep_dict = { + "observation.image": [PILImage.fromarray(x.numpy()) for x in image], + "observation.state": state, + "action": action, + "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_index": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.image": next_image, + # "next.observation.state": next_state, + "next.reward": next_reward, + "next.done": next_done, + } + ep_dicts.append(ep_dict) + + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames) + + id_from = id_to + episode_id += 1 + + data_dict = concatenate_episodes(ep_dicts) + return data_dict, episode_data_index + + def to_hf_dataset(self, data_dict): + features = { + "observation.image": Image(), + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + "next.reward": Value(dtype="float32", id=None), + "next.done": Value(dtype="bool", id=None), + #'next.success': Value(dtype='bool', id=None), + "index": Value(dtype="int64", id=None), + } + features = Features(features) + hf_dataset = Dataset.from_dict(data_dict, features=features) + hf_dataset.set_transform(hf_transform_to_torch) + + return hf_dataset + + def cleanup(self): + pass diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index f7186b6a..ea035371 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -342,7 +342,6 @@ def compute_stats(hf_dataset, batch_size=32, max_num_samples=None): "max": max[key], "min": min[key], } - return stats diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py new file mode 100644 index 00000000..0c04fd37 --- /dev/null +++ 
b/lerobot/scripts/push_dataset_to_hub.py @@ -0,0 +1,338 @@ +import argparse +import json +import shutil +from pathlib import Path +from typing import Any, Protocol + +import torch +from datasets import Dataset +from huggingface_hub import HfApi +from safetensors.torch import save_file + +from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw +from lerobot.common.datasets.push_dataset_to_hub.aloha_processor import ( + AlohaProcessor, +) +from lerobot.common.datasets.push_dataset_to_hub.pusht_processor import PushTProcessor +from lerobot.common.datasets.push_dataset_to_hub.umi_processor import UmiProcessor +from lerobot.common.datasets.push_dataset_to_hub.xarm_processor import XarmProcessor +from lerobot.common.datasets.utils import compute_stats, flatten_dict + + +def push_lerobot_dataset_to_hub( + hf_dataset: Dataset, + episode_data_index: dict[str, list[int]], + info: dict[str, Any], + stats: dict[str, dict[str, torch.Tensor]], + root: Path, + revision: str, + dataset_id: str, + community_id: str = "lerobot", + dry_run: bool = False, +) -> None: + """ + Pushes a dataset to the Hugging Face Hub. + + Args: + hf_dataset (Dataset): The dataset to be pushed. + episode_data_index (dict[str, list[int]]): The index of episode data. + info (dict[str, Any]): Information about the dataset, eg. fps. + stats (dict[str, dict[str, torch.Tensor]]): Statistics of the dataset. + root (Path): The root directory of the dataset. + revision (str): The revision of the dataset. + dataset_id (str): The ID of the dataset. + community_id (str, optional): The ID of the community or the user where the + dataset will be stored. Defaults to "lerobot". + dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False. 
+ """ + if not dry_run: + # push to main to indicate latest version + hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True) + + # push to version branch + hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True, revision=revision) + + # create and store meta_data + meta_data_dir = root / community_id / dataset_id / "meta_data" + meta_data_dir.mkdir(parents=True, exist_ok=True) + + # info + info_path = meta_data_dir / "info.json" + + with open(str(info_path), "w") as f: + json.dump(info, f, indent=4) + # stats + stats_path = meta_data_dir / "stats.safetensors" + save_file(flatten_dict(stats), stats_path) + + # episode_data_index + episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index} + ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors" + save_file(episode_data_index, ep_data_idx_path) + + if not dry_run: + api = HfApi() + + api.upload_file( + path_or_fileobj=info_path, + path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""), + repo_id=f"{community_id}/{dataset_id}", + repo_type="dataset", + ) + api.upload_file( + path_or_fileobj=info_path, + path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""), + repo_id=f"{community_id}/{dataset_id}", + repo_type="dataset", + revision=revision, + ) + + # stats + api.upload_file( + path_or_fileobj=stats_path, + path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""), + repo_id=f"{community_id}/{dataset_id}", + repo_type="dataset", + ) + api.upload_file( + path_or_fileobj=stats_path, + path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""), + repo_id=f"{community_id}/{dataset_id}", + repo_type="dataset", + revision=revision, + ) + + api.upload_file( + path_or_fileobj=ep_data_idx_path, + path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""), + repo_id=f"{community_id}/{dataset_id}", + repo_type="dataset", + ) + api.upload_file( + path_or_fileobj=ep_data_idx_path, + path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""), + repo_id=f"{community_id}/{dataset_id}", + repo_type="dataset", + revision=revision, + ) + + # copy in tests folder, the first episode and the meta_data directory + num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0] + hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk( + f"tests/data/{community_id}/{dataset_id}/train" + ) + if Path(f"tests/data/{community_id}/{dataset_id}/meta_data").exists(): + shutil.rmtree(f"tests/data/{community_id}/{dataset_id}/meta_data") + shutil.copytree(meta_data_dir, f"tests/data/{community_id}/{dataset_id}/meta_data") + + +def push_dataset_to_hub( + dataset_id: str, + root: Path, + fps: int | None, + dataset_folder: Path | None = None, + dry_run: bool = False, + revision: str = "v1.1", + community_id: str = "lerobot", + no_preprocess: bool = False, + path_save_to_disk: str | None = None, + **kwargs, +) -> None: + """ + Download a raw dataset if needed or access a local raw dataset, detect the raw format (e.g. aloha, pusht, umi) and process it accordingly in a common data format which is then pushed to the Hugging Face Hub. + + Args: + dataset_id (str): The ID of the dataset. + root (Path): The root directory where the dataset will be downloaded. + fps (int | None): The desired frames per second for the dataset. + dataset_folder (Path | None, optional): The path to the dataset folder. 
If not provided, the dataset will be downloaded using the dataset ID. Defaults to None.
+        dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False.
+        revision (str, optional): Version of the `push_dataset_to_hub.py` codebase used to preprocess the dataset. Defaults to "v1.1".
+        community_id (str, optional): The ID of the community or user hosting the dataset on the Hub. Defaults to "lerobot".
+        no_preprocess (bool, optional): If True, does not preprocess the dataset. Defaults to False.
+        path_save_to_disk (str | None, optional): The path where the processed dataset is saved to disk. Works together with `dry_run=True`, which allows saving the dataset locally without uploading it. By default, the dataset is not saved on disk.
+        **kwargs: Additional keyword arguments for the preprocessor init method.
+    """
+    if dataset_folder is None:
+        dataset_folder = download_raw(root=root, dataset_id=dataset_id)
+
+    if not no_preprocess:
+        processor = guess_dataset_type(dataset_folder=dataset_folder, fps=fps, **kwargs)
+        data_dict, episode_data_index = processor.preprocess()
+        hf_dataset = processor.to_hf_dataset(data_dict)
+
+        info = {
+            "fps": processor.fps,
+        }
+        stats: dict[str, dict[str, torch.Tensor]] = compute_stats(hf_dataset)
+
+        push_lerobot_dataset_to_hub(
+            hf_dataset=hf_dataset,
+            episode_data_index=episode_data_index,
+            info=info,
+            stats=stats,
+            root=root,
+            revision=revision,
+            dataset_id=dataset_id,
+            community_id=community_id,
+            dry_run=dry_run,
+        )
+        if path_save_to_disk:
+            hf_dataset.with_format("torch").save_to_disk(dataset_path=str(path_save_to_disk))
+
+        processor.cleanup()
+
+
+class DatasetProcessor(Protocol):
+    """A protocol for dataset processors.
+
+    It defines the methods used to validate, preprocess, and convert a raw dataset.
+
+    Args:
+        folder_path (str): The path to the folder containing the dataset.
+        fps (int | None): The frames per second of the dataset. If None, the default value is used.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
+    """
+
+    def __init__(self, folder_path: str, fps: int | None, *args, **kwargs) -> None: ...
+
+    def is_valid(self) -> bool:
+        """Check if the dataset is valid.
+
+        Returns:
+            bool: True if the dataset is valid, False otherwise.
+        """
+        ...
+
+    def preprocess(self) -> tuple[dict, dict]:
+        """Preprocess the dataset.
+
+        Returns:
+            tuple[dict, dict]: A tuple containing two dictionaries representing the preprocessed data.
+        """
+        ...
+
+    def to_hf_dataset(self, data_dict: dict) -> Dataset:
+        """Convert the preprocessed data to a Hugging Face dataset.
+
+        Args:
+            data_dict (dict): The preprocessed data.
+
+        Returns:
+            Dataset: The converted Hugging Face dataset.
+        """
+        ...
+
+    @property
+    def fps(self) -> int:
+        """Get the frames per second of the dataset.
+
+        Returns:
+            int: The frames per second.
+        """
+        ...
+
+    def cleanup(self):
+        """Clean up any resources used by the dataset processor."""
+        ...
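[Editor's note: a minimal sketch, not part of the diff, of a processor that satisfies the DatasetProcessor protocol above. The ToyProcessor name and its single `episodes.npy` file layout are hypothetical; only the method names, return types, and the Features/transform pattern follow what the protocol and the existing processors use.]

from pathlib import Path

import numpy as np
import torch
from datasets import Dataset, Features, Sequence, Value

from lerobot.common.datasets.utils import hf_transform_to_torch


class ToyProcessor:
    """Hypothetical processor for a folder containing a single `episodes.npy` state array."""

    def __init__(self, folder_path: str, fps: int | None = None):
        self.folder_path = Path(folder_path)
        self._fps = 10 if fps is None else fps

    @property
    def fps(self) -> int:
        return self._fps

    def is_valid(self) -> bool:
        return (self.folder_path / "episodes.npy").exists()

    def preprocess(self) -> tuple[dict, dict]:
        # A single episode made of (num_frames, state_dim) states.
        states = torch.from_numpy(np.load(self.folder_path / "episodes.npy")).float()
        num_frames = states.shape[0]
        data_dict = {
            "observation.state": states,
            "episode_index": torch.zeros(num_frames, dtype=torch.int),
            "frame_index": torch.arange(num_frames),
            "timestamp": torch.arange(num_frames) / self.fps,
            "index": torch.arange(num_frames),
        }
        episode_data_index = {"from": [0], "to": [num_frames]}
        return data_dict, episode_data_index

    def to_hf_dataset(self, data_dict: dict) -> Dataset:
        features = Features(
            {
                "observation.state": Sequence(
                    length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32")
                ),
                "episode_index": Value(dtype="int64"),
                "frame_index": Value(dtype="int64"),
                "timestamp": Value(dtype="float32"),
                "index": Value(dtype="int64"),
            }
        )
        hf_dataset = Dataset.from_dict(data_dict, features=features)
        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset

    def cleanup(self):
        pass

[With such a class in place, `guess_dataset_type` below would only need one more `is_valid()` check, which is what the TODO about a registration mechanism hints at.]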
+ + +def guess_dataset_type(dataset_folder: Path, **processor_kwargs) -> DatasetProcessor: + if (processor := AlohaProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid(): + return processor + if (processor := XarmProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid(): + return processor + if (processor := PushTProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid(): + return processor + if (processor := UmiProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid(): + return processor + # TODO: Propose a registration mechanism for new dataset types + raise ValueError(f"Could not guess dataset type for folder {dataset_folder}") + + +def main(): + """ + Main function to process command line arguments and push dataset to Hugging Face Hub. + + Parses command line arguments to get dataset details and conditions under which the dataset + is processed and pushed. It manages dataset preparation and uploading based on the user-defined parameters. + """ + parser = argparse.ArgumentParser( + description="Push a dataset to the Hugging Face Hub with optional parameters for customization.", + epilog=""" + Example usage: + python -m lerobot.scripts.push_dataset_to_hub --dataset-folder /path/to/dataset --dataset-id example_dataset --root /path/to/root --dry-run --revision v2.0 --community-id example_community --fps 30 --path-save-to-disk /path/to/save --no-preprocess + + This processes and optionally pushes 'example_dataset' located in '/path/to/dataset' to Hugging Face Hub, + with various parameters to control the processing and uploading behavior. + """, + ) + + parser.add_argument( + "--dataset-folder", + type=Path, + default=None, + help="The filesystem path to the dataset folder. If not provided, the dataset must be identified and managed by other means.", + ) + parser.add_argument( + "--dataset-id", + type=str, + required=True, + help="Unique identifier for the dataset to be processed and uploaded.", + ) + parser.add_argument( + "--root", type=Path, required=True, help="Root directory where the dataset operations are managed." + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Simulate the push process without uploading any data, for testing purposes.", + ) + parser.add_argument( + "--community-id", + type=str, + default="lerobot", + help="Community or user ID under which the dataset will be hosted on the Hub.", + ) + parser.add_argument( + "--fps", + type=int, + help="Target frame rate for video or image sequence datasets. 
Optional and applicable only if the dataset includes temporal media.", + ) + parser.add_argument( + "--revision", + type=str, + default="v1.0", + help="Dataset version identifier to manage different iterations of the dataset.", + ) + parser.add_argument( + "--no-preprocess", + action="store_true", + help="Does not preprocess the dataset, set this flag if you only want dowload the dataset raw.", + ) + parser.add_argument( + "--path-save-to-disk", + type=Path, + help="Optional path where the processed dataset can be saved locally.", + ) + + args = parser.parse_args() + + push_dataset_to_hub( + dataset_folder=args.dataset_folder, + dataset_id=args.dataset_id, + root=args.root, + fps=args.fps, + dry_run=args.dry_run, + community_id=args.community_id, + revision=args.revision, + no_preprocess=args.no_preprocess, + path_save_to_disk=args.path_save_to_disk, + ) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index dafdec4d..3e9845cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ authors = [ "Alexander Soare ", "Quentin Gallouédec ", "Simon Alibert ", + "Adil Zouitine ", "Thomas Wolf ", ] repository = "https://github.com/huggingface/lerobot" @@ -66,7 +67,6 @@ dev = ["pre-commit", "debugpy"] test = ["pytest", "pytest-cov"] umi = ["imagecodecs"] - [tool.ruff] line-length = 110 target-version = "py310" diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py index a8ea1065..82a43917 100644 --- a/tests/scripts/save_dataset_to_safetensors.py +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -16,6 +16,7 @@ from pathlib import Path from safetensors.torch import save_file +from lerobot import available_datasets from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -26,8 +27,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): shutil.rmtree(data_dir) data_dir.mkdir(parents=True, exist_ok=True) - - dataset = LeRobotDataset(repo_id) + dataset = LeRobotDataset(repo_id=repo_id, root=data_dir) # save 2 first frames of first episode i = dataset.episode_data_index["from"][0].item() @@ -64,4 +64,11 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): if __name__ == "__main__": - save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors") + available_datasets = [ + "lerobot/pusht", + "lerobot/xarm_push_medium", + "lerobot/aloha_sim_insertion_human", + "lerobot/umi_cup_in_the_wild", + ] + for dataset in available_datasets: + save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index bebc3479..9438abb7 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -241,57 +241,65 @@ def test_flatten_unflatten_dict(): def test_backward_compatibility(): """This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`.""" - repo_id = "lerobot/pusht" + all_repo_id = [ + "lerobot/pusht", + # TODO (azouitine): Add artifacts for the following datasets + # "lerobot/aloha_sim_insertion_human", + # "lerobot/xarm_push_medium", + # "lerobot/umi_cup_in_the_wild", + ] + for repo_id in all_repo_id: + dataset = LeRobotDataset( + repo_id, + root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None, + ) - dataset = LeRobotDataset( - repo_id, - root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None, - ) + data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id - data_dir = 
Path("tests/data/save_dataset_to_safetensors") / repo_id + def load_and_compare(i): + new_frame = dataset[i] # noqa: B023 + old_frame = load_file(data_dir / f"frame_{i}.safetensors") # noqa: B023 - def load_and_compare(i): - new_frame = dataset[i] - old_frame = load_file(data_dir / f"frame_{i}.safetensors") + new_keys = set(new_frame.keys()) + old_keys = set(old_frame.keys()) + assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same" - new_keys = set(new_frame.keys()) - old_keys = set(old_frame.keys()) - assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same" + for key in new_frame: + assert ( + new_frame[key] == old_frame[key] + ).all(), f"{key=} for index={i} does not contain the same value" - for key in new_frame: - assert ( - new_frame[key] == old_frame[key] - ).all(), f"{key=} for index={i} does not contain the same value" + # test2 first frames of first episode + i = dataset.episode_data_index["from"][0].item() + load_and_compare(i) + load_and_compare(i + 1) - # test2 first frames of first episode - i = dataset.episode_data_index["from"][0].item() - load_and_compare(i) - load_and_compare(i + 1) + # test 2 frames at the middle of first episode + i = int( + (dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2 + ) + load_and_compare(i) + load_and_compare(i + 1) - # test 2 frames at the middle of first episode - i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) - load_and_compare(i) - load_and_compare(i + 1) + # test 2 last frames of first episode + i = dataset.episode_data_index["to"][0].item() + load_and_compare(i - 2) + load_and_compare(i - 1) - # test 2 last frames of first episode - i = dataset.episode_data_index["to"][0].item() - load_and_compare(i - 2) - load_and_compare(i - 1) + # TODO(rcadene): Enable testing on second and last episode + # We currently cant because our test dataset only contains the first episode - # TODO(rcadene): Enable testing on second and last episode - # We currently cant because our test dataset only contains the first episode + # # test 2 first frames of second episode + # i = dataset.episode_data_index["from"][1].item() + # load_and_compare(i) + # load_and_compare(i+1) - # # test 2 first frames of second episode - # i = dataset.episode_data_index["from"][1].item() - # load_and_compare(i) - # load_and_compare(i+1) + # #test 2 last frames of second episode + # i = dataset.episode_data_index["to"][1].item() + # load_and_compare(i-2) + # load_and_compare(i-1) - # #test 2 last frames of second episode - # i = dataset.episode_data_index["to"][1].item() - # load_and_compare(i-2) - # load_and_compare(i-1) - - # # test 2 last frames of last episode - # i = dataset.episode_data_index["to"][-1].item() - # load_and_compare(i-2) - # load_and_compare(i-1) + # # test 2 last frames of last episode + # i = dataset.episode_data_index["to"][-1].item() + # load_and_compare(i-2) + # load_and_compare(i-1)