Refactor push_dataset_to_hub (#118)

parent 2765877f28
commit e4e739f4f8

@@ -4,20 +4,22 @@ useless dependencies when using datasets.
 """
 
+import io
+import logging
 import shutil
 from pathlib import Path
 
 import tqdm
 
 
-def download_raw(root, dataset_id) -> Path:
+def download_raw(raw_dir, dataset_id):
     if "pusht" in dataset_id:
-        return download_pusht(root=root, dataset_id=dataset_id)
+        download_pusht(raw_dir)
     elif "xarm" in dataset_id:
-        return download_xarm(root=root, dataset_id=dataset_id)
+        download_xarm(raw_dir)
     elif "aloha" in dataset_id:
-        return download_aloha(root=root, dataset_id=dataset_id)
+        download_aloha(raw_dir, dataset_id)
     elif "umi" in dataset_id:
-        return download_umi(root=root, dataset_id=dataset_id)
+        download_umi(raw_dir)
     else:
         raise ValueError(dataset_id)
 
 
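For illustration, a minimal usage sketch of the refactored entry point above; the path is hypothetical, but it follows the `{dataset_id}_raw` convention used in the `__main__` block later in this file:

    from pathlib import Path

    # each raw dataset now lives in an explicit directory chosen by the caller
    raw_dir = Path("data") / "pusht_raw"
    download_raw(raw_dir, "pusht")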
@@ -45,32 +47,27 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
 
         with zipfile.ZipFile(zip_file, "r") as zip_ref:
             zip_ref.extractall(destination_folder)
         return True
     else:
         return False
 
 
-def download_pusht(root: str, dataset_id: str = "pusht", fps: int = 10) -> Path:
+def download_pusht(raw_dir: str):
     pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
-    pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
 
-    root = Path(root)
-    raw_dir: Path = root / f"{dataset_id}_raw"
-    zarr_path: Path = (raw_dir / pusht_zarr).resolve()
-    if not zarr_path.is_dir():
+    raw_dir = Path(raw_dir)
     raw_dir.mkdir(parents=True, exist_ok=True)
     download_and_extract_zip(pusht_url, raw_dir)
-    return zarr_path
+    # file is created inside a useless "pusht" directory, so we move it out and delete the dir
+    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
+    shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
+    shutil.rmtree(raw_dir / "pusht")
 
 
-def download_xarm(root: str, dataset_id: str, fps: int = 15) -> Path:
-    root = Path(root)
-    raw_dir: Path = root / "xarm_datasets_raw"
-    if not raw_dir.exists():
+def download_xarm(raw_dir: Path):
+    """Download all xarm datasets at once"""
     import zipfile
 
     import gdown
 
+    raw_dir = Path(raw_dir)
+    raw_dir.mkdir(parents=True, exist_ok=True)
     # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
     url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
@@ -78,17 +75,25 @@ def download_xarm(root: str, dataset_id: str, fps: int = 15) -> Path:
     gdown.download(url, str(zip_path), quiet=False)
     print("Extracting...")
     with zipfile.ZipFile(str(zip_path), "r") as zip_f:
-        for member in zip_f.namelist():
-            if member.startswith("data/xarm") and member.endswith(".pkl"):
-                print(member)
-                zip_f.extract(member=member)
+        for pkl_path in zip_f.namelist():
+            if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
+                zip_f.extract(member=pkl_path)
+                # move to corresponding raw directory
+                extract_dir = pkl_path.replace("/buffer.pkl", "")
+                raw_pkl_path = raw_dir / "buffer.pkl"
+                shutil.move(pkl_path, raw_pkl_path)
+                shutil.rmtree(extract_dir)
     zip_path.unlink()
 
-    dataset_path: Path = root / f"{dataset_id}"
-    return dataset_path
 
+def download_aloha(raw_dir: Path, dataset_id: str):
+    # TODO(rcadene): remove gdown and use hugging face download instead
+    import gdown
+
+    logging.warning(
+        "Aloha download is broken and requires a custom version of gdown which is not limited in the number of files"
+    )
 
-def download_aloha(root: str, dataset_id: str) -> Path:
     folder_urls = {
         "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
         "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
@@ -129,15 +134,12 @@ def download_aloha(root: str, dataset_id: str) -> Path:
         "aloha_sim_transfer_cube_human": ["top"],
         "aloha_sim_transfer_cube_scripted": ["top"],
     }
-    root = Path(root)
-    raw_dir: Path = root / f"{dataset_id}_raw"
-    if not raw_dir.is_dir():
-        import gdown
 
     assert dataset_id in folder_urls
     assert dataset_id in ep48_urls
     assert dataset_id in ep49_urls
 
+    raw_dir = Path(raw_dir)
+    raw_dir.mkdir(parents=True, exist_ok=True)
 
     gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir))
@@ -145,24 +147,19 @@ def download_aloha(root: str, dataset_id: str) -> Path:
     # because of the 50 files limit per directory, two files episode 48 and 49 were missing
     gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
     gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
-    return raw_dir
 
 
-def download_umi(root: str, dataset_id: str) -> Path:
+def download_umi(raw_dir: Path):
     url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
-    cup_in_the_wild_zarr = Path("umi/cup_in_the_wild/cup_in_the_wild.zarr")
+    zarr_path = raw_dir / "cup_in_the_wild.zarr"
 
-    root = Path(root)
-    raw_dir: Path = root / f"{dataset_id}_raw"
-    zarr_path: Path = (raw_dir / cup_in_the_wild_zarr).resolve()
-    if not zarr_path.is_dir():
+    raw_dir = Path(raw_dir)
     raw_dir.mkdir(parents=True, exist_ok=True)
     download_and_extract_zip(url_cup_in_the_wild, zarr_path)
-    return zarr_path
 
 
 if __name__ == "__main__":
-    root = "data"
+    data_dir = Path("data")
     dataset_ids = [
         "pusht",
         "xarm_lift_medium",
@@ -176,4 +173,5 @@ if __name__ == "__main__":
         "umi_cup_in_the_wild",
     ]
     for dataset_id in dataset_ids:
-        download_raw(root=root, dataset_id=dataset_id)
+        raw_dir = data_dir / f"{dataset_id}_raw"
+        download_raw(raw_dir, dataset_id)

@@ -0,0 +1,163 @@
"""
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
"""

import re
import shutil
from pathlib import Path

import h5py
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)

# TODO(rcadene): enable for PR video dataset
# from lerobot.common.datasets.video_utils import encode_video_frames


def check_format(raw_dir) -> bool:
    cameras = ["top"]

    hdf5_files: list[Path] = list(raw_dir.glob("episode_*.hdf5"))
    assert len(hdf5_files) != 0
    hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+).hdf5", x.name).group(1)))

    # Check if the sequence is consecutive, e.g. episode_0, episode_1, episode_2, etc.
    previous_number = None
    for file in hdf5_files:
        current_number = int(re.search(r"episode_(\d+).hdf5", file.name).group(1))
        if previous_number is not None:
            assert current_number - previous_number == 1
        previous_number = current_number

    for file in hdf5_files:
        with h5py.File(file, "r") as file:
            # Check for the expected datasets within the HDF5 file
            required_datasets = ["/action", "/observations/qpos"]
            # Add camera-specific image datasets to the required datasets
            camera_datasets = [f"/observations/images/{cam}" for cam in cameras]
            required_datasets.extend(camera_datasets)

            assert all(dataset in file for dataset in required_datasets)
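A hedged usage sketch of `check_format` above: despite the `-> bool` annotation it raises `AssertionError` on failure, so a bare call doubles as a sanity check (path hypothetical):

    from pathlib import Path

    raw_dir = Path("data/aloha_sim_insertion_human_raw")
    check_format(raw_dir)  # passes silently, or raises if files are missing or non-consecutive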


def load_from_raw(raw_dir, out_dir, fps, video, debug):
    hdf5_files = list(raw_dir.glob("*.hdf5"))
    hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+)", x.name).group(1)))
    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0

    for ep_path in tqdm.tqdm(hdf5_files):
        with h5py.File(ep_path, "r") as ep:
            ep_idx = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
            num_frames = ep["/action"].shape[0]

            # last step of demonstration is considered done
            done = torch.zeros(num_frames, dtype=torch.bool)
            done[-1] = True

            state = torch.from_numpy(ep["/observations/qpos"][:])
            action = torch.from_numpy(ep["/action"][:])

            ep_dict = {}

            cameras = list(ep["/observations/images"].keys())
            for cam in cameras:
                img_key = f"observation.images.{cam}"
                imgs_array = ep[f"/observations/images/{cam}"][:]  # b h w c
                if video:
                    # save png images in temporary directory
                    tmp_imgs_dir = out_dir / "tmp_images"
                    save_images_concurrently(imgs_array, tmp_imgs_dir)

                    # encode images to a mp4 video
                    video_path = out_dir / "videos" / f"{img_key}_episode_{ep_idx:06d}.mp4"
                    encode_video_frames(tmp_imgs_dir, video_path, fps)  # noqa: F821

                    # clean temporary images directory
                    shutil.rmtree(tmp_imgs_dir)

                    # store the episode idx
                    ep_dict[img_key] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
                else:
                    ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

            ep_dict["observation.state"] = state
            ep_dict["action"] = action
            ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
            ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
            ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
            ep_dict["next.done"] = done
            # TODO(rcadene): compute reward and success
            # ep_dict["next.reward"] = reward
            # ep_dict["next.success"] = success

            assert isinstance(ep_idx, int)
            ep_dicts.append(ep_dict)

            episode_data_index["from"].append(id_from)
            episode_data_index["to"].append(id_from + num_frames)

        id_from += num_frames

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)
    return data_dict, episode_data_index
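To make the bookkeeping concrete, a small hypothetical example of how `episode_data_index` relates to the concatenated `data_dict` returned above:

    # two episodes of 3 and 2 frames give:
    # episode_data_index == {"from": [0, 3], "to": [3, 5]}
    ep_idx = 1
    frames = data_dict["action"][episode_data_index["from"][ep_idx] : episode_data_index["to"][ep_idx]]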


def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}

    image_keys = [key for key in data_dict if "observation.images." in key]
    for image_key in image_keys:
        if video:
            features[image_key] = Value(dtype="int64", id="video")
        else:
            features[image_key] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)
    # TODO(rcadene): add reward and success
    # features["next.reward"] = Value(dtype="float32", id=None)
    # features["next.success"] = Value(dtype="bool", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 50

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
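A hypothetical end-to-end call of the ALOHA conversion above (paths illustrative; `video=False` avoids the not-yet-enabled `encode_video_frames` path):

    from pathlib import Path

    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir=Path("data/aloha_sim_insertion_human_raw"),
        out_dir=Path("data/aloha_sim_insertion_human"),
        video=False,
    )
    print(info)  # {'fps': 50, 'video': False}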

@@ -1,199 +0,0 @@
import re
from pathlib import Path

import h5py
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)


class AlohaProcessor:
    """
    Process HDF5 files formatted like in: https://github.com/tonyzhaozh/act

    Attributes:
        folder_path (Path): Path to the directory containing HDF5 files.
        cameras (list[str]): List of camera identifiers to check in the files.
        fps (int): Frames per second used in timestamp calculations.

    Methods:
        is_valid() -> bool:
            Validates if each HDF5 file within the folder contains all required datasets.
        preprocess() -> dict:
            Processes the files and returns structured data suitable for further analysis.
        to_hf_dataset(data_dict: dict) -> Dataset:
            Converts processed data into a Hugging Face Dataset object.
    """

    def __init__(self, folder_path: Path, cameras: list[str] | None = None, fps: int | None = None):
        """
        Initializes the AlohaProcessor with a specified directory path containing HDF5 files,
        an optional list of cameras, and a frame rate.

        Args:
            folder_path (Path): The directory path where HDF5 files are stored.
            cameras (list[str] | None): Optional list of cameras to validate within the files. Defaults to ['top'] if None.
            fps (int): Frame rate for the datasets, used in time calculations. Default is 50.

        Examples:
            >>> processor = AlohaProcessor(Path("path_to_hdf5_directory"), ["camera1", "camera2"])
            >>> processor.is_valid()
            True
        """
        self.folder_path = folder_path
        if cameras is None:
            cameras = ["top"]
        self.cameras = cameras
        if fps is None:
            fps = 50
        self._fps = fps

    @property
    def fps(self) -> int:
        return self._fps

    def is_valid(self) -> bool:
        """
        Validates the HDF5 files in the specified folder to ensure they contain the required datasets
        for actions, positions, and images for each specified camera.

        Returns:
            bool: True if all files are valid HDF5 files with all required datasets, False otherwise.
        """
        hdf5_files: list[Path] = list(self.folder_path.glob("episode_*.hdf5"))
        if len(hdf5_files) == 0:
            return False
        try:
            hdf5_files = sorted(
                hdf5_files, key=lambda x: int(re.search(r"episode_(\d+).hdf5", x.name).group(1))
            )
        except AttributeError:
            # All file names must contain a numerical identifier matching 'episode_(\d+).hdf5'
            return False

        # Check if the sequence is consecutive, e.g. episode_0, episode_1, episode_2, etc.
        # If not, return False
        previous_number = None
        for file in hdf5_files:
            current_number = int(re.search(r"episode_(\d+).hdf5", file.name).group(1))
            if previous_number is not None and current_number - previous_number != 1:
                return False
            previous_number = current_number

        for file in hdf5_files:
            try:
                with h5py.File(file, "r") as file:
                    # Check for the expected datasets within the HDF5 file
                    required_datasets = ["/action", "/observations/qpos"]
                    # Add camera-specific image datasets to the required datasets
                    camera_datasets = [f"/observations/images/{cam}" for cam in self.cameras]
                    required_datasets.extend(camera_datasets)

                    if not all(dataset in file for dataset in required_datasets):
                        return False
            except OSError:
                return False
        return True

    def preprocess(self):
        """
        Collects episode data from the HDF5 file and returns it as an AlohaStep named tuple.

        Returns:
            AlohaStep: Named tuple containing episode data.

        Raises:
            ValueError: If the file is not valid.
        """
        if not self.is_valid():
            raise ValueError("The HDF5 file is invalid or does not contain the required datasets.")

        hdf5_files = list(self.folder_path.glob("*.hdf5"))
        hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+)", x.name).group(1)))
        ep_dicts = []
        episode_data_index = {"from": [], "to": []}

        id_from = 0

        for ep_path in tqdm.tqdm(hdf5_files):
            with h5py.File(ep_path, "r") as ep:
                ep_id = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
                num_frames = ep["/action"].shape[0]

                # last step of demonstration is considered done
                done = torch.zeros(num_frames, dtype=torch.bool)
                done[-1] = True

                state = torch.from_numpy(ep["/observations/qpos"][:])
                action = torch.from_numpy(ep["/action"][:])

                ep_dict = {}

                for cam in self.cameras:
                    image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])  # b h w c
                    ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]

                ep_dict.update(
                    {
                        "observation.state": state,
                        "action": action,
                        "episode_index": torch.tensor([ep_id] * num_frames),
                        "frame_index": torch.arange(0, num_frames, 1),
                        "timestamp": torch.arange(0, num_frames, 1) / self.fps,
                        # TODO(rcadene): compute reward and success
                        # "next.reward": reward,
                        "next.done": done,
                        # "next.success": success,
                    }
                )

                assert isinstance(ep_id, int)
                ep_dicts.append(ep_dict)

                episode_data_index["from"].append(id_from)
                episode_data_index["to"].append(id_from + num_frames)

            id_from += num_frames

        data_dict = concatenate_episodes(ep_dicts)
        return data_dict, episode_data_index

    def to_hf_dataset(self, data_dict) -> Dataset:
        """
        Converts a dictionary of data into a Hugging Face Dataset object.

        Args:
            data_dict (dict): A dictionary containing the data to be converted.

        Returns:
            Dataset: The converted Hugging Face Dataset object.
        """
        image_features = {f"observation.images.{cam}": Image() for cam in self.cameras}
        features = {
            "observation.state": Sequence(
                length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            # "next.reward": Value(dtype="float32", id=None),
            "next.done": Value(dtype="bool", id=None),
            # "next.success": Value(dtype="bool", id=None),
            "index": Value(dtype="int64", id=None),
        }
        update_features = {**image_features, **features}
        features = Features(update_features)
        hf_dataset = Dataset.from_dict(data_dict, features=features)
        hf_dataset.set_transform(hf_transform_to_torch)

        return hf_dataset

    def cleanup(self):
        pass

@@ -1,180 +0,0 @@
from pathlib import Path

import numpy as np
import torch
import tqdm
import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)


class PushTProcessor:
    """Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""

    def __init__(self, folder_path: Path, fps: int | None = None):
        self.zarr_path = folder_path
        if fps is None:
            fps = 10
        self._fps = fps

    @property
    def fps(self) -> int:
        return self._fps

    def is_valid(self):
        try:
            zarr_data = zarr.open(self.zarr_path, mode="r")
        except Exception:
            # TODO (azouitine): Handle the exception properly
            return False
        required_datasets = {
            "data/action",
            "data/img",
            "data/keypoint",
            "data/n_contacts",
            "data/state",
            "meta/episode_ends",
        }
        for dataset in required_datasets:
            if dataset not in zarr_data:
                return False
        nb_frames = zarr_data["data/img"].shape[0]

        required_datasets.remove("meta/episode_ends")

        return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)

    def preprocess(self):
        try:
            import pymunk
            from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely

            from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
                ReplayBuffer as DiffusionPolicyReplayBuffer,
            )
        except ModuleNotFoundError as e:
            print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
            raise e

        # as defined in env
        success_threshold = 0.95  # 95% coverage

        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
            self.zarr_path
        )  # , keys=['img', 'state', 'action'])

        episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
        num_episodes = dataset_dict.meta["episode_ends"].shape[0]
        assert (
            len({dataset_dict[key].shape[0] for key in dataset_dict.keys()}) == 1  # noqa: SIM118
        ), "Some data types don't have the same number of total frames."

        # TODO: verify that goal pose is expected to be fixed
        goal_pos_angle = np.array([256, 256, np.pi / 4])  # x, y, theta (in radians)
        goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

        imgs = torch.from_numpy(dataset_dict["img"])  # b h w c
        states = torch.from_numpy(dataset_dict["state"])
        actions = torch.from_numpy(dataset_dict["action"])

        ep_dicts = []
        episode_data_index = {"from": [], "to": []}

        id_from = 0
        for episode_id in tqdm.tqdm(range(num_episodes)):
            id_to = dataset_dict.meta["episode_ends"][episode_id]

            num_frames = id_to - id_from

            assert (episode_ids[id_from:id_to] == episode_id).all()

            image = imgs[id_from:id_to]
            assert image.min() >= 0.0
            assert image.max() <= 255.0
            image = image.type(torch.uint8)

            state = states[id_from:id_to]
            agent_pos = state[:, :2]
            block_pos = state[:, 2:4]
            block_angle = state[:, 4]

            reward = torch.zeros(num_frames)
            success = torch.zeros(num_frames, dtype=torch.bool)
            done = torch.zeros(num_frames, dtype=torch.bool)
            for i in range(num_frames):
                space = pymunk.Space()
                space.gravity = 0, 0
                space.damping = 0

                # Add walls.
                walls = [
                    PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
                    PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
                    PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
                    PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
                ]
                space.add(*walls)

                block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
                goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
                block_geom = pymunk_to_shapely(block_body, block_body.shapes)
                intersection_area = goal_geom.intersection(block_geom).area
                goal_area = goal_geom.area
                coverage = intersection_area / goal_area
                reward[i] = np.clip(coverage / success_threshold, 0, 1)
                success[i] = coverage > success_threshold

            # last step of demonstration is considered done
            done[-1] = True

            ep_dict = {
                "observation.image": [PILImage.fromarray(x.numpy()) for x in image],
                "observation.state": agent_pos,
                "action": actions[id_from:id_to],
                "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
                "frame_index": torch.arange(0, num_frames, 1),
                "timestamp": torch.arange(0, num_frames, 1) / self.fps,
                # "next.observation.image": image[1:],
                # "next.observation.state": agent_pos[1:],
                # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
                "next.reward": torch.cat([reward[1:], reward[[-1]]]),
                "next.done": torch.cat([done[1:], done[[-1]]]),
                "next.success": torch.cat([success[1:], success[[-1]]]),
            }
            ep_dicts.append(ep_dict)

            episode_data_index["from"].append(id_from)
            episode_data_index["to"].append(id_from + num_frames)

            id_from += num_frames

        data_dict = concatenate_episodes(ep_dicts)
        return data_dict, episode_data_index

    def to_hf_dataset(self, data_dict):
        features = {
            "observation.image": Image(),
            "observation.state": Sequence(
                length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            "next.reward": Value(dtype="float32", id=None),
            "next.done": Value(dtype="bool", id=None),
            "next.success": Value(dtype="bool", id=None),
            "index": Value(dtype="int64", id=None),
        }
        features = Features(features)
        hf_dataset = Dataset.from_dict(data_dict, features=features)
        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset

    def cleanup(self):
        pass

@@ -0,0 +1,214 @@
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""

import shutil
from pathlib import Path

import numpy as np
import torch
import tqdm
import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)

# TODO(rcadene): enable for PR video dataset
# from lerobot.common.datasets.video_utils import encode_video_frames


def check_format(raw_dir):
    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    zarr_data = zarr.open(zarr_path, mode="r")

    required_datasets = {
        "data/action",
        "data/img",
        "data/keypoint",
        "data/n_contacts",
        "data/state",
        "meta/episode_ends",
    }
    for dataset in required_datasets:
        assert dataset in zarr_data
    nb_frames = zarr_data["data/img"].shape[0]

    required_datasets.remove("meta/episode_ends")

    assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


def load_from_raw(raw_dir, out_dir, fps, video, debug):
    try:
        import pymunk
        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely

        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
            ReplayBuffer as DiffusionPolicyReplayBuffer,
        )
    except ModuleNotFoundError as e:
        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
        raise e
    # as defined in the gym-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
    success_threshold = 0.95  # 95% coverage

    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

    episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
    num_episodes = zarr_data.meta["episode_ends"].shape[0]
    assert (
        len({zarr_data[key].shape[0] for key in zarr_data.keys()}) == 1  # noqa: SIM118
    ), "Some data types don't have the same number of total frames."

    # TODO(rcadene): verify that goal pose is expected to be fixed
    goal_pos_angle = np.array([256, 256, np.pi / 4])  # x, y, theta (in radians)
    goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

    imgs = torch.from_numpy(zarr_data["img"])  # b h w c
    states = torch.from_numpy(zarr_data["state"])
    actions = torch.from_numpy(zarr_data["action"])

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    for ep_idx in tqdm.tqdm(range(num_episodes)):
        id_to = zarr_data.meta["episode_ends"][ep_idx]
        num_frames = id_to - id_from

        # sanity check
        assert (episode_ids[id_from:id_to] == ep_idx).all()

        # get image
        image = imgs[id_from:id_to]
        assert image.min() >= 0.0
        assert image.max() <= 255.0
        image = image.type(torch.uint8)

        # get state
        state = states[id_from:id_to]
        agent_pos = state[:, :2]
        block_pos = state[:, 2:4]
        block_angle = state[:, 4]

        # get reward, success, done
        reward = torch.zeros(num_frames)
        success = torch.zeros(num_frames, dtype=torch.bool)
        done = torch.zeros(num_frames, dtype=torch.bool)
        for i in range(num_frames):
            space = pymunk.Space()
            space.gravity = 0, 0
            space.damping = 0

            # Add walls.
            walls = [
                PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
                PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
                PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
                PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
            ]
            space.add(*walls)

            block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
            goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
            block_geom = pymunk_to_shapely(block_body, block_body.shapes)
            intersection_area = goal_geom.intersection(block_geom).area
            goal_area = goal_geom.area
            coverage = intersection_area / goal_area
            reward[i] = np.clip(coverage / success_threshold, 0, 1)
            success[i] = coverage > success_threshold

        # last step of demonstration is considered done
        done[-1] = True

        ep_dict = {}

        imgs_array = [x.numpy() for x in image]
        if video:
            # save png images in temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to a mp4 video
            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
            encode_video_frames(tmp_imgs_dir, video_path, fps)  # noqa: F821

            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)

            # store the episode index
            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        else:
            ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = agent_pos
        ep_dict["action"] = actions[id_from:id_to]
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        # ep_dict["next.observation.image"] = image[1:]
        # ep_dict["next.observation.state"] = agent_pos[1:]
        # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
        ep_dict["next.reward"] = torch.cat([reward[1:], reward[[-1]]])
        ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
        ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)

        id_from += num_frames

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)
    return data_dict, episode_data_index
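The `next.*` columns above are the per-frame signals shifted one step forward, with the final value repeated; a tiny self-contained demonstration of the `torch.cat` shift:

    import torch

    reward = torch.tensor([0.0, 0.2, 0.9])
    next_reward = torch.cat([reward[1:], reward[[-1]]])
    print(next_reward)  # tensor([0.2000, 0.9000, 0.9000])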


def to_hf_dataset(data_dict, video):
    features = {}

    if video:
        features["observation.image"] = Value(dtype="int64", id="video")
    else:
        features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.reward"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["next.success"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 10

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
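For parity with the ALOHA example, a hypothetical conversion call for PushT; the column set follows `to_hf_dataset` above (order may differ):

    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir=Path("data/pusht_raw"), out_dir=Path("data/pusht"), video=False
    )
    # columns include: observation.image, observation.state, action, episode_index,
    # frame_index, timestamp, next.reward, next.done, next.success, index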

@@ -1,280 +0,0 @@
import os
import re
import shutil
from glob import glob

import numpy as np
import torch
import tqdm
import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)


class UmiProcessor:
    """
    Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface

    Attributes:
        folder_path (str): The path to the folder containing Zarr datasets.
        fps (int): Frames per second, used to calculate timestamps for frames.

    """

    def __init__(self, folder_path: str, fps: int | None = None):
        self.zarr_path = folder_path
        if fps is None:
            # https://arxiv.org/pdf/2402.10329#table.caption.16
            fps = 10  # For umi cup in the wild
        self._fps = fps
        register_codecs()

    @property
    def fps(self) -> int:
        return self._fps

    def is_valid(self) -> bool:
        """
        Validates the Zarr folder to ensure it contains all required datasets with consistent frame counts.

        Returns:
            bool: True if all required datasets are present and have consistent frame counts, False otherwise.
        """
        # Check if the Zarr folder is valid
        try:
            zarr_data = zarr.open(self.zarr_path, mode="r")
        except Exception:
            # TODO (azouitine): Handle the exception properly
            return False
        required_datasets = {
            "data/robot0_demo_end_pose",
            "data/robot0_demo_start_pose",
            "data/robot0_eef_pos",
            "data/robot0_eef_rot_axis_angle",
            "data/robot0_gripper_width",
            "meta/episode_ends",
            "data/camera0_rgb",
        }
        for dataset in required_datasets:
            if dataset not in zarr_data:
                return False
        nb_frames = zarr_data["data/camera0_rgb"].shape[0]

        required_datasets.remove("meta/episode_ends")

        return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)

    def preprocess(self):
        """
        Collects and processes all episodes from the Zarr dataset into structured data dictionaries.

        Returns:
            Tuple[Dict, Dict]: A tuple containing the structured episode data and episode index mappings.
        """
        zarr_data = zarr.open(self.zarr_path, mode="r")

        # We process the image data separately because it is too large to fit in memory
        end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
        start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
        eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
        eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
        gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])

        states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
        states = torch.cat([states_pos, gripper_width], dim=1)

        episode_ends = zarr_data["meta/episode_ends"][:]
        num_episodes: int = episode_ends.shape[0]

        episode_ids = torch.from_numpy(self.get_episode_idxs(episode_ends))

        # We convert it to a torch tensor later because the jit function does not support torch tensors
        episode_ends = torch.from_numpy(episode_ends)

        ep_dicts = []
        episode_data_index = {"from": [], "to": []}
        id_from = 0

        for episode_id in tqdm.tqdm(range(num_episodes)):
            id_to = episode_ends[episode_id]

            num_frames = id_to - id_from

            assert (
                episode_ids[id_from:id_to] == episode_id
            ).all(), f"episode_ids[{id_from}:{id_to}] != {episode_id}"

            state = states[id_from:id_to]
            ep_dict = {
                # observation.image will be filled later
                "observation.state": state,
                "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
                "frame_index": torch.arange(0, num_frames, 1),
                "timestamp": torch.arange(0, num_frames, 1) / self.fps,
                "episode_data_index_from": torch.tensor([id_from] * num_frames),
                "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
                "end_pose": end_pose[id_from:id_to],
                "start_pos": start_pos[id_from:id_to],
                "gripper_width": gripper_width[id_from:id_to],
            }
            ep_dicts.append(ep_dict)
            episode_data_index["from"].append(id_from)
            episode_data_index["to"].append(id_from + num_frames)
            id_from += num_frames

        data_dict = concatenate_episodes(ep_dicts)

        total_frames = id_from
        data_dict["index"] = torch.arange(0, total_frames, 1)

        print("Saving images to disk in temporary folder...")
        # datasets.Image() can take a list of paths to images, so we save the images to a temporary folder
        # to avoid loading them all in memory
        _save_images_concurrently(
            data=zarr_data, image_key="data/camera0_rgb", folder_path="tmp_umi_images", max_workers=12
        )
        print("Saving images to disk in temporary folder... Done")

        # Sort files by number, e.g. 1.png, 2.png, 3.png, 9.png, 10.png instead of 1.png, 10.png, 2.png, 3.png, 9.png,
        # to correctly match the images with the data
        images_path = sorted(
            glob("tmp_umi_images/*"), key=lambda x: int(re.search(r"(\d+)\.png$", x).group(1))
        )
        data_dict["observation.image"] = images_path
        print("Images saved to disk, do not forget to delete the folder tmp_umi_images/")

        # Cleanup
        return data_dict, episode_data_index

    def to_hf_dataset(self, data_dict):
        """
        Converts the processed data dictionary into a Hugging Face dataset with defined features.

        Args:
            data_dict (Dict): The data dictionary containing tensors and episode information.

        Returns:
            Dataset: A Hugging Face dataset constructed from the provided data dictionary.
        """
        features = {
            "observation.image": Image(),
            "observation.state": Sequence(
                length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            "index": Value(dtype="int64", id=None),
            "episode_data_index_from": Value(dtype="int64", id=None),
            "episode_data_index_to": Value(dtype="int64", id=None),
            # `start_pos` and `end_pos` respectively represent the positions of the end-effector
            # at the beginning and the end of the episode.
            # `gripper_width` indicates the distance between the grippers, and this value is included
            # in the state vector, which comprises the concatenation of the end-effector position
            # and gripper width.
            "end_pose": Sequence(
                length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "start_pos": Sequence(
                length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "gripper_width": Sequence(
                length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
            ),
        }
        features = Features(features)
        hf_dataset = Dataset.from_dict(data_dict, features=features)
        hf_dataset.set_transform(hf_transform_to_torch)

        return hf_dataset

    def cleanup(self):
        # Cleanup
        if os.path.exists("tmp_umi_images"):
            print("Removing temporary images folder")
            shutil.rmtree("tmp_umi_images")
            print("Cleanup done")

    @classmethod
    def get_episode_idxs(cls, episode_ends: np.ndarray) -> np.ndarray:
        # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
        from numba import jit

        @jit(nopython=True)
        def _get_episode_idxs(episode_ends):
            result = np.zeros((episode_ends[-1],), dtype=np.int64)
            start_idx = 0
            for episode_number, end_idx in enumerate(episode_ends):
                result[start_idx:end_idx] = episode_number
                start_idx = end_idx
            return result

        return _get_episode_idxs(episode_ends)


def _clear_folder(folder_path: str):
    """
    Clears all the content of the specified folder. Creates the folder if it does not exist.

    Args:
        folder_path (str): Path to the folder to clear.

    Examples:
        >>> import os
        >>> os.makedirs('example_folder', exist_ok=True)
        >>> with open('example_folder/temp_file.txt', 'w') as f:
        ...     f.write('example')
        >>> _clear_folder('example_folder')
        >>> os.listdir('example_folder')
        []
    """
    if os.path.exists(folder_path):
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f"Failed to delete {file_path}. Reason: {e}")
    else:
        os.makedirs(folder_path)


def _save_image(img_array: np.array, i: int, folder_path: str):
    """
    Saves a single image to the specified folder.

    Args:
        img_array (ndarray): The numpy array of the image.
        i (int): Index of the image, used for naming.
        folder_path (str): Path to the folder where the image will be saved.
    """
    img = PILImage.fromarray(img_array)
    img_format = "PNG" if img_array.dtype == np.uint8 else "JPEG"
    img.save(os.path.join(folder_path, f"{i}.{img_format.lower()}"), quality=100)


def _save_images_concurrently(data: dict, image_key: str, folder_path: str, max_workers: int = 4):
    """
    Saves images from the zarr data to the specified folder using multithreading.

    Args:
        data (dict): A dictionary containing image data in an array format.
        image_key (str): Key under which the image array is stored.
        folder_path (str): Path to the folder where images will be saved.
        max_workers (int): The maximum number of threads to use for saving images.
    """
    from concurrent.futures import ThreadPoolExecutor

    num_images = len(data[image_key])
    _clear_folder(folder_path)  # Clear or create folder first

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        [executor.submit(_save_image, data[image_key][i], i, folder_path) for i in range(num_images)]

@@ -0,0 +1,207 @@
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""

import logging
import shutil
from pathlib import Path

import numpy as np
import torch
import tqdm
import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)

# TODO(rcadene): enable for PR video dataset
# from lerobot.common.datasets.video_utils import encode_video_frames


def check_format(raw_dir) -> bool:
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    zarr_data = zarr.open(zarr_path, mode="r")

    required_datasets = {
        "data/robot0_demo_end_pose",
        "data/robot0_demo_start_pose",
        "data/robot0_eef_pos",
        "data/robot0_eef_rot_axis_angle",
        "data/robot0_gripper_width",
        "meta/episode_ends",
        "data/camera0_rgb",
    }
    for dataset in required_datasets:
        if dataset not in zarr_data:
            return False

    # mandatory to access zarr_data
    register_codecs()
    nb_frames = zarr_data["data/camera0_rgb"].shape[0]

    required_datasets.remove("meta/episode_ends")
    assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
    # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
    from numba import jit

    @jit(nopython=True)
    def _get_episode_idxs(episode_ends):
        result = np.zeros((episode_ends[-1],), dtype=np.int64)
        start_idx = 0
        for episode_number, end_idx in enumerate(episode_ends):
            result[start_idx:end_idx] = episode_number
            start_idx = end_idx
        return result

    return _get_episode_idxs(episode_ends)
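A tiny worked example of `get_episode_idxs`: the cumulative `episode_ends` boundaries are expanded into one episode id per frame (values illustrative):

    import numpy as np

    episode_ends = np.array([3, 5])
    print(get_episode_idxs(episode_ends))  # [0 0 0 1 1]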


def load_from_raw(raw_dir, out_dir, fps, video, debug):
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    zarr_data = zarr.open(zarr_path, mode="r")

    # We process the image data separately because it is too large to fit in memory
    end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
    start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
    eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
    eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
    gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])

    states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
    states = torch.cat([states_pos, gripper_width], dim=1)

    episode_ends = zarr_data["meta/episode_ends"][:]
    num_episodes = episode_ends.shape[0]

    episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))

    # We convert it to a torch tensor later because the jit function does not support torch tensors
    episode_ends = torch.from_numpy(episode_ends)

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    for ep_idx in tqdm.tqdm(range(num_episodes)):
        id_to = episode_ends[ep_idx]
        num_frames = id_to - id_from

        # sanity check
        assert (episode_ids[id_from:id_to] == ep_idx).all()

        # TODO(rcadene): save temporary images of the episode?

        state = states[id_from:id_to]

        ep_dict = {}

        # load 57MB of images in RAM (400x224x224x3 uint8)
        imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
        if video:
            # save png images in temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to a mp4 video
            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
            encode_video_frames(tmp_imgs_dir, video_path, fps)  # noqa: F821

            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)

            # store the episode index
            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        else:
            ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = state
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
        ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames)
        ep_dict["end_pose"] = end_pose[id_from:id_to]
        ep_dict["start_pos"] = start_pos[id_from:id_to]
        ep_dict["gripper_width"] = gripper_width[id_from:id_to]
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)
        id_from += num_frames

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = id_from
    data_dict["index"] = torch.arange(0, total_frames, 1)

    return data_dict, episode_data_index
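The UMI state assembled above concatenates end-effector position, axis-angle rotation, and gripper width; a quick hypothetical shape check, assuming the usual 3 + 3 + 1 dimensions:

    import torch

    eff_pos = torch.zeros(100, 3)
    eff_rot_axis_angle = torch.zeros(100, 3)
    gripper_width = torch.zeros(100, 1)
    states = torch.cat([torch.cat([eff_pos, eff_rot_axis_angle], dim=1), gripper_width], dim=1)
    print(states.shape)  # torch.Size([100, 7])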


def to_hf_dataset(data_dict, video):
    features = {}

    if video:
        features["observation.image"] = Value(dtype="int64", id="video")
    else:
        features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["index"] = Value(dtype="int64", id=None)
    features["episode_data_index_from"] = Value(dtype="int64", id=None)
    features["episode_data_index_to"] = Value(dtype="int64", id=None)
    # `start_pos` and `end_pos` respectively represent the positions of the end-effector
    # at the beginning and the end of the episode.
    # `gripper_width` indicates the distance between the grippers, and this value is included
    # in the state vector, which comprises the concatenation of the end-effector position
    # and gripper width.
    features["end_pose"] = Sequence(
        length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["start_pos"] = Sequence(
        length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["gripper_width"] = Sequence(
        length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
    )

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        # For umi cup in the wild: https://arxiv.org/pdf/2402.10329#table.caption.16
        fps = 10

    if not video:
        logging.warning(
            "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
        )

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info

@@ -1,3 +1,8 @@
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+import numpy
+import PIL
 import torch
 
 
@@ -18,3 +23,16 @@ def concatenate_episodes(ep_dicts):
     total_frames = data_dict["frame_index"].shape[0]
     data_dict["index"] = torch.arange(0, total_frames, 1)
     return data_dict
+
+
+def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
+    out_dir = Path(out_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    def save_image(img_array, i, out_dir):
+        img = PIL.Image.fromarray(img_array)
+        img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
+
+    num_images = len(imgs_array)
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
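A hedged usage sketch of the new shared helper above; the array shape `(num_frames, h, w, c)` with dtype uint8 is an assumption about the typical input:

    from pathlib import Path

    import numpy as np

    imgs_array = np.zeros((5, 64, 64, 3), dtype=np.uint8)
    save_images_concurrently(imgs_array, Path("tmp_images"), max_workers=4)
    # writes tmp_images/frame_000000.png ... tmp_images/frame_000004.png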

@@ -0,0 +1,163 @@
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""

import pickle
import shutil
from pathlib import Path

import einops
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)

# TODO(rcadene): enable for PR video dataset
# from lerobot.common.datasets.video_utils import encode_video_frames


def check_format(raw_dir):
    keys = {"actions", "rewards", "dones"}
    nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}

    xarm_files = list(raw_dir.glob("*.pkl"))
    assert len(xarm_files) > 0

    with open(xarm_files[0], "rb") as f:
        dataset_dict = pickle.load(f)

    assert isinstance(dataset_dict, dict)
    assert all(k in dataset_dict for k in keys)

    # Check for consistent lengths in nested keys
    expected_len = len(dataset_dict["actions"])
    assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict)

    for key, subkeys in nested_keys.items():
        nested_dict = dataset_dict.get(key, {})
        assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
|
||||
|
||||
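
# For reference, a sketch of the smallest pickled dict that would satisfy
# check_format. Shapes are illustrative (one 2-frame episode); only the key
# structure and matching lengths are actually checked.
import numpy as np

fake_buffer = {
    "actions": np.zeros((2, 4), dtype=np.float32),
    "rewards": np.zeros((2,), dtype=np.float32),
    "dones": np.array([False, True]),
    "observations": {
        "rgb": np.zeros((2, 3, 84, 84), dtype=np.uint8),
        "state": np.zeros((2, 4), dtype=np.float32),
    },
    "next_observations": {
        "rgb": np.zeros((2, 3, 84, 84), dtype=np.uint8),
        "state": np.zeros((2, 4), dtype=np.float32),
    },
}

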
def load_from_raw(raw_dir, out_dir, fps, video, debug):
    pkl_path = raw_dir / "buffer.pkl"

    with open(pkl_path, "rb") as f:
        pkl_data = pickle.load(f)

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    id_to = 0
    ep_idx = 0
    total_frames = pkl_data["actions"].shape[0]
    for i in tqdm.tqdm(range(total_frames)):
        id_to += 1

        if not pkl_data["dones"][i]:
            continue

        num_frames = id_to - id_from

        image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to])
        image = einops.rearrange(image, "b c h w -> b h w c")
        state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to])
        action = torch.tensor(pkl_data["actions"][id_from:id_to])
        # TODO(rcadene): we have a missing last frame which is the observation when the env is done
        # it is critical to have this frame for tdmpc to predict a "done observation/state"
        # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to])
        # next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to])
        next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to])
        next_done = torch.tensor(pkl_data["dones"][id_from:id_to])

        ep_dict = {}

        imgs_array = [x.numpy() for x in image]
        if video:
            # save png images in a temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images into an mp4 video
            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
            encode_video_frames(tmp_imgs_dir, video_path, fps)  # noqa: F821

            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)

            # store the episode index
            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        else:
            ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = state
        ep_dict["action"] = action
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        # ep_dict["next.observation.image"] = next_image
        # ep_dict["next.observation.state"] = next_state
        ep_dict["next.reward"] = next_reward
        ep_dict["next.done"] = next_done
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)

        id_from = id_to
        ep_idx += 1

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)
    return data_dict, episode_data_index
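
# To make the episode bookkeeping above concrete, an illustrative trace:
#
#     dones = [False, False, True, False, True]
#     # -> two episodes, spanning frames [0, 3) and [3, 5)
#     # -> episode_data_index == {"from": [0, 3], "to": [3, 5]}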


def to_hf_dataset(data_dict, video):
    features = {}

    if video:
        features["observation.image"] = Value(dtype="int64", id="video")
    else:
        features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.reward"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)
    # TODO(rcadene): add success
    # features["next.success"] = Value(dtype='bool', id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
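
# For a hypothetical 4-dimensional state and action, the schema assembled
# above is equivalent to this sketch (image mode shown; in video mode
# "observation.image" is the int64 placeholder until the video PR lands):
example_features = Features(
    {
        "observation.image": Image(),
        "observation.state": Sequence(length=4, feature=Value(dtype="float32")),
        "action": Sequence(length=4, feature=Value(dtype="float32")),
        "episode_index": Value(dtype="int64"),
        "frame_index": Value(dtype="int64"),
        "timestamp": Value(dtype="float32"),
        "next.reward": Value(dtype="float32"),
        "next.done": Value(dtype="bool"),
        "index": Value(dtype="int64"),
    }
)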


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 15

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
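
# A hypothetical invocation of the xarm converter; paths are illustrative,
# not prescribed by the code above.
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
    raw_dir=Path("data/xarm_lift_medium_raw"),
    out_dir=Path("data/lerobot/xarm_lift_medium"),
    fps=None,     # falls back to the xarm default of 15
    video=False,  # video encoding is gated on the pending video PR
    debug=True,   # convert only the first episode
)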
@@ -1,145 +0,0 @@
import pickle
from pathlib import Path

import einops
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)


class XarmProcessor:
    """Process pickle files formatted like in: https://github.com/fyhMer/fowm"""

    def __init__(self, folder_path: str, fps: int | None = None):
        self.folder_path = Path(folder_path)
        self.keys = {"actions", "rewards", "dones"}
        self.nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
        if fps is None:
            fps = 15
        self._fps = fps

    @property
    def fps(self) -> int:
        return self._fps

    def is_valid(self) -> bool:
        # get all .pkl files
        xarm_files = list(self.folder_path.glob("*.pkl"))
        if len(xarm_files) != 1:
            return False

        try:
            with open(xarm_files[0], "rb") as f:
                dataset_dict = pickle.load(f)
        except Exception:
            return False

        if not isinstance(dataset_dict, dict):
            return False

        if not all(k in dataset_dict for k in self.keys):
            return False

        # Check for consistent lengths in nested keys
        try:
            expected_len = len(dataset_dict["actions"])
            if any(len(dataset_dict[key]) != expected_len for key in self.keys if key in dataset_dict):
                return False

            for key, subkeys in self.nested_keys.items():
                nested_dict = dataset_dict.get(key, {})
                if any(
                    len(nested_dict[subkey]) != expected_len for subkey in subkeys if subkey in nested_dict
                ):
                    return False
        except KeyError:  # If any expected key or subkey is missing
            return False

        return True  # All checks passed

    def preprocess(self):
        if not self.is_valid():
            raise ValueError("The Xarm file is invalid or does not contain the required datasets.")

        xarm_files = list(self.folder_path.glob("*.pkl"))

        with open(xarm_files[0], "rb") as f:
            dataset_dict = pickle.load(f)
        ep_dicts = []
        episode_data_index = {"from": [], "to": []}

        id_from = 0
        id_to = 0
        episode_id = 0
        total_frames = dataset_dict["actions"].shape[0]
        for i in tqdm.tqdm(range(total_frames)):
            id_to += 1

            if not dataset_dict["dones"][i]:
                continue

            num_frames = id_to - id_from

            image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to])
            image = einops.rearrange(image, "b c h w -> b h w c")
            state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to])
            action = torch.tensor(dataset_dict["actions"][id_from:id_to])
            # TODO(rcadene): we have a missing last frame which is the observation when the env is done
            # it is critical to have this frame for tdmpc to predict a "done observation/state"
            # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to])
            # next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to])
            next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to])
            next_done = torch.tensor(dataset_dict["dones"][id_from:id_to])

            ep_dict = {
                "observation.image": [PILImage.fromarray(x.numpy()) for x in image],
                "observation.state": state,
                "action": action,
                "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
                "frame_index": torch.arange(0, num_frames, 1),
                "timestamp": torch.arange(0, num_frames, 1) / self.fps,
                # "next.observation.image": next_image,
                # "next.observation.state": next_state,
                "next.reward": next_reward,
                "next.done": next_done,
            }
            ep_dicts.append(ep_dict)

            episode_data_index["from"].append(id_from)
            episode_data_index["to"].append(id_from + num_frames)

            id_from = id_to
            episode_id += 1

        data_dict = concatenate_episodes(ep_dicts)
        return data_dict, episode_data_index

    def to_hf_dataset(self, data_dict):
        features = {
            "observation.image": Image(),
            "observation.state": Sequence(
                length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            "next.reward": Value(dtype="float32", id=None),
            "next.done": Value(dtype="bool", id=None),
            # 'next.success': Value(dtype='bool', id=None),
            "index": Value(dtype="int64", id=None),
        }
        features = Features(features)
        hf_dataset = Dataset.from_dict(data_dict, features=features)
        hf_dataset.set_transform(hf_transform_to_torch)

        return hf_dataset

    def cleanup(self):
        pass
@@ -1,295 +1,215 @@
"""
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
installation of neural net specific packages like pytorch, tensorflow, jax.

Example:
```
python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id pusht \
--raw-format pusht_zarr \
--community-id lerobot \
--revision v1.2 \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1

python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id xarm_lift_medium \
--raw-format xarm_pkl \
--community-id lerobot \
--revision v1.2 \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1

python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id aloha_sim_insertion_scripted \
--raw-format aloha_hdf5 \
--community-id lerobot \
--revision v1.2 \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1

python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id umi_cup_in_the_wild \
--raw-format umi_zarr \
--community-id lerobot \
--revision v1.2 \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1
```
"""

import argparse
import json
import shutil
from pathlib import Path
from typing import Any, Protocol

import torch
from datasets import Dataset
from huggingface_hub import HfApi
from safetensors.torch import save_file

from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.aloha_processor import (
    AlohaProcessor,
)
from lerobot.common.datasets.push_dataset_to_hub.pusht_processor import PushTProcessor
from lerobot.common.datasets.push_dataset_to_hub.umi_processor import UmiProcessor
from lerobot.common.datasets.push_dataset_to_hub.xarm_processor import XarmProcessor
from lerobot.common.datasets.utils import compute_stats, flatten_dict

def push_lerobot_dataset_to_hub(
    hf_dataset: Dataset,
    episode_data_index: dict[str, list[int]],
    info: dict[str, Any],
    stats: dict[str, dict[str, torch.Tensor]],
    root: Path,
    revision: str,
    dataset_id: str,
    community_id: str = "lerobot",
    dry_run: bool = False,
) -> None:
    """
    Pushes a dataset to the Hugging Face Hub.
def get_from_raw_to_lerobot_format_fn(raw_format):
    if raw_format == "pusht_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "umi_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "aloha_hdf5":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
    elif raw_format == "xarm_pkl":
        from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
    else:
        raise ValueError(raw_format)

    Args:
        hf_dataset (Dataset): The dataset to be pushed.
        episode_data_index (dict[str, list[int]]): The index of episode data.
        info (dict[str, Any]): Information about the dataset, e.g. fps.
        stats (dict[str, dict[str, torch.Tensor]]): Statistics of the dataset.
        root (Path): The root directory of the dataset.
        revision (str): The revision of the dataset.
        dataset_id (str): The ID of the dataset.
        community_id (str, optional): The ID of the community or the user where the
            dataset will be stored. Defaults to "lerobot".
        dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False.
    """
    if not dry_run:
        # push to main to indicate latest version
        hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True)
    return from_raw_to_lerobot_format
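
# Sketch of the intended dispatch, mirroring the call made later in
# push_dataset_to_hub (raw_dir/out_dir/fps/video/debug assumed in scope):
#
#     from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn("xarm_pkl")
#     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)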

    # push to version branch
    hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True, revision=revision)

    # create and store meta_data
    meta_data_dir = root / community_id / dataset_id / "meta_data"
def save_meta_data(info, stats, episode_data_index, meta_data_dir):
    meta_data_dir.mkdir(parents=True, exist_ok=True)

    # info
    # save info
    info_path = meta_data_dir / "info.json"

    with open(str(info_path), "w") as f:
        json.dump(info, f, indent=4)
    # stats

    # save stats
    stats_path = meta_data_dir / "stats.safetensors"
    save_file(flatten_dict(stats), stats_path)

    # episode_data_index
    # save episode_data_index
    episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
    ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
    save_file(episode_data_index, ep_data_idx_path)

    if not dry_run:

def push_meta_data_to_hub(meta_data_dir, repo_id, revision):
    api = HfApi()

    def upload(filename, revision):
        api.upload_file(
            path_or_fileobj=info_path,
            path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
        )
        api.upload_file(
            path_or_fileobj=info_path,
            path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
            path_or_fileobj=meta_data_dir / filename,
            path_in_repo=f"meta_data/{filename}",
            repo_id=repo_id,
            revision=revision,
            repo_type="dataset",
        )

        # stats
        api.upload_file(
            path_or_fileobj=stats_path,
            path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
        )
        api.upload_file(
            path_or_fileobj=stats_path,
            path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
            revision=revision,
        )

        api.upload_file(
            path_or_fileobj=ep_data_idx_path,
            path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
        )
        api.upload_file(
            path_or_fileobj=ep_data_idx_path,
            path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
            revision=revision,
        )

        # copy in tests folder, the first episode and the meta_data directory
        num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
        hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
            f"tests/data/{community_id}/{dataset_id}/train"
        )
        if Path(f"tests/data/{community_id}/{dataset_id}/meta_data").exists():
            shutil.rmtree(f"tests/data/{community_id}/{dataset_id}/meta_data")
        shutil.copytree(meta_data_dir, f"tests/data/{community_id}/{dataset_id}/meta_data")
    upload("info.json", "main")
    upload("info.json", revision)
    upload("stats.safetensors", "main")
    upload("stats.safetensors", revision)
    upload("episode_data_index.safetensors", "main")
    upload("episode_data_index.safetensors", revision)
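
# After save_meta_data and push_meta_data_to_hub run, the local out_dir and
# the Hub repo both contain a meta_data/ folder laid out as follows (file
# names taken from the code above):
#
#     meta_data/info.json
#     meta_data/stats.safetensors
#     meta_data/episode_data_index.safetensors

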
def push_dataset_to_hub(
    data_dir: Path,
    dataset_id: str,
    root: Path,
    raw_format: str | None,
    community_id: str,
    revision: str,
    dry_run: bool,
    save_to_disk: bool,
    tests_data_dir: Path,
    save_tests_to_disk: bool,
    fps: int | None,
    dataset_folder: Path | None = None,
    dry_run: bool = False,
    revision: str = "v1.1",
    community_id: str = "lerobot",
    no_preprocess: bool = False,
    path_save_to_disk: str | None = None,
    **kwargs,
) -> None:
    """
    Download a raw dataset if needed or access a local raw dataset, detect the raw format (e.g. aloha, pusht, umi) and process it accordingly in a common data format which is then pushed to the Hugging Face Hub.
    video: bool,
    debug: bool,
):
    raw_dir = data_dir / f"{dataset_id}_raw"

    Args:
        dataset_id (str): The ID of the dataset.
        root (Path): The root directory where the dataset will be downloaded.
        fps (int | None): The desired frames per second for the dataset.
        dataset_folder (Path | None, optional): The path to the dataset folder. If not provided, the dataset will be downloaded using the dataset ID. Defaults to None.
        dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False.
        revision (str, optional): Version of the `push_dataset_to_hub.py` codebase used to preprocess the dataset. Defaults to "v1.1".
        community_id (str, optional): The ID of the community. Defaults to "lerobot".
        no_preprocess (bool, optional): If True, does not preprocess the dataset. Defaults to False.
        path_save_to_disk (str | None, optional): The path to save the dataset to disk. Works when `dry_run` is True, which allows saving to disk only, without uploading. By default, the dataset is not saved on disk.
        **kwargs: Additional keyword arguments for the preprocessor init method.
    out_dir = data_dir / community_id / dataset_id
    meta_data_dir = out_dir / "meta_data"
    videos_dir = out_dir / "videos"

    tests_out_dir = tests_data_dir / community_id / dataset_id
    tests_meta_data_dir = tests_out_dir / "meta_data"

    """
    if dataset_folder is None:
        dataset_folder = download_raw(root=root, dataset_id=dataset_id)
    if out_dir.exists():
        shutil.rmtree(out_dir)

    if not no_preprocess:
        processor = guess_dataset_type(dataset_folder=dataset_folder, fps=fps, **kwargs)
        data_dict, episode_data_index = processor.preprocess()
        hf_dataset = processor.to_hf_dataset(data_dict)
    if tests_out_dir.exists():
        shutil.rmtree(tests_out_dir)

        info = {
            "fps": processor.fps,
        }
        stats: dict[str, dict[str, torch.Tensor]] = compute_stats(hf_dataset)
    if not raw_dir.exists():
        download_raw(raw_dir, dataset_id)

        push_lerobot_dataset_to_hub(
            hf_dataset=hf_dataset,
            episode_data_index=episode_data_index,
            info=info,
            stats=stats,
            root=root,
            revision=revision,
            dataset_id=dataset_id,
            community_id=community_id,
            dry_run=dry_run,
        )
        if path_save_to_disk:
            hf_dataset.with_format("torch").save_to_disk(dataset_path=str(path_save_to_disk))
    if raw_format is None:
        # TODO(rcadene, adilzouitine): implement auto_find_raw_format
        raise NotImplementedError()
        # raw_format = auto_find_raw_format(raw_dir)

        processor.cleanup()
    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)

    # convert dataset from original raw format to LeRobot format
    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)


class DatasetProcessor(Protocol):
    """A class for processing datasets.
    stats = compute_stats(hf_dataset)

    This class provides methods for validating, preprocessing, and converting datasets.
    if save_to_disk:
        hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
        hf_dataset.save_to_disk(str(out_dir / "train"))

    Args:
        folder_path (str): The path to the folder containing the dataset.
        fps (int | None): The frames per second of the dataset. If None, the default value is used.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.
    """
    if not dry_run or save_to_disk:
        # mandatory for upload
        save_meta_data(info, stats, episode_data_index, meta_data_dir)

    def __init__(self, folder_path: str, fps: int | None, *args, **kwargs) -> None: ...
    if not dry_run:
        repo_id = f"{community_id}/{dataset_id}"
        hf_dataset.push_to_hub(repo_id, token=True, revision="main")
        hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
        push_meta_data_to_hub(repo_id, meta_data_dir)
        if video:
            push_meta_data_to_hub(repo_id, videos_dir)

    def is_valid(self) -> bool:
        """Check if the dataset is valid.
    if save_tests_to_disk:
        # get the first episode
        num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
        test_hf_dataset = hf_dataset.select(range(num_items_first_ep))

        Returns:
            bool: True if the dataset is valid, False otherwise.
        """
        ...
        test_hf_dataset = test_hf_dataset.with_format(None)
        test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))

    def preprocess(self) -> tuple[dict, dict]:
        """Preprocess the dataset.

        Returns:
            tuple[dict, dict]: A tuple containing two dictionaries representing the preprocessed data.
        """
        ...

    def to_hf_dataset(self, data_dict: dict) -> Dataset:
        """Convert the preprocessed data to a Hugging Face dataset.

        Args:
            data_dict (dict): The preprocessed data.

        Returns:
            Dataset: The converted Hugging Face dataset.
        """
        ...

    @property
    def fps(self) -> int:
        """Get the frames per second of the dataset.

        Returns:
            int: The frames per second.
        """
        ...

    def cleanup(self):
        """Clean up any resources used by the dataset processor."""
        ...


def guess_dataset_type(dataset_folder: Path, **processor_kwargs) -> DatasetProcessor:
    if (processor := AlohaProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
        return processor
    if (processor := XarmProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
        return processor
    if (processor := PushTProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
        return processor
    if (processor := UmiProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
        return processor
    # TODO: Propose a registration mechanism for new dataset types
    raise ValueError(f"Could not guess dataset type for folder {dataset_folder}")
    # copy meta data to tests directory
    if Path(tests_meta_data_dir).exists():
        shutil.rmtree(tests_meta_data_dir)
    shutil.copytree(meta_data_dir, tests_meta_data_dir)


def main():
    """
    Main function to process command line arguments and push dataset to Hugging Face Hub.

    Parses command line arguments to get dataset details and conditions under which the dataset
    is processed and pushed. It manages dataset preparation and uploading based on the user-defined parameters.
    """
    parser = argparse.ArgumentParser(
        description="Push a dataset to the Hugging Face Hub with optional parameters for customization.",
        epilog="""
    Example usage:
        python -m lerobot.scripts.push_dataset_to_hub --dataset-folder /path/to/dataset --dataset-id example_dataset --root /path/to/root --dry-run --revision v2.0 --community-id example_community --fps 30 --path-save-to-disk /path/to/save --no-preprocess

    This processes and optionally pushes 'example_dataset' located in '/path/to/dataset' to Hugging Face Hub,
    with various parameters to control the processing and uploading behavior.
    """,
    )
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset-folder",
        "--data-dir",
        type=Path,
        default=None,
        help="The filesystem path to the dataset folder. If not provided, the dataset must be identified and managed by other means.",
        required=True,
        help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
    )
    parser.add_argument(
        "--dataset-id",
        type=str,
        required=True,
        help="Unique identifier for the dataset to be processed and uploaded.",
        help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
    )
    parser.add_argument(
        "--root", type=Path, required=True, help="Root directory where the dataset operations are managed."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Simulate the push process without uploading any data, for testing purposes.",
        "--raw-format",
        type=str,
        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
    )
    parser.add_argument(
        "--community-id",
@@ -297,41 +217,57 @@ def main():
        default="lerobot",
        help="Community or user ID under which the dataset will be hosted on the Hub.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        help="Target frame rate for video or image sequence datasets. Optional and applicable only if the dataset includes temporal media.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="v1.0",
        help="Dataset version identifier to manage different iterations of the dataset.",
        default="v1.2",
        help="Codebase version used to generate the dataset.",
    )
    parser.add_argument(
        "--no-preprocess",
        action="store_true",
        help="Does not preprocess the dataset; set this flag if you only want to download the raw dataset.",
        "--dry-run",
        type=int,
        default=0,
        help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
    )
    parser.add_argument(
        "--path-save-to-disk",
        "--save-to-disk",
        type=int,
        default=1,
        help="Save the dataset in the directory specified by `--data-dir`.",
    )
    parser.add_argument(
        "--tests-data-dir",
        type=Path,
        help="Optional path where the processed dataset can be saved locally.",
        default="tests/data",
        help="Directory containing test artifact datasets.",
    )
    parser.add_argument(
        "--save-tests-to-disk",
        type=int,
        default=1,
        help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        help="Frame rate used to collect videos. If not provided, use the default one specified in the code.",
    )
    parser.add_argument(
        "--video",
        type=int,
        # TODO(rcadene): enable when video PR merges
        default=0,
        help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 times faster loading time during training.",
    )
    parser.add_argument(
        "--debug",
        type=int,
        default=0,
        help="Debug mode: process the first episode only.",
    )

    args = parser.parse_args()

    push_dataset_to_hub(
        dataset_folder=args.dataset_folder,
        dataset_id=args.dataset_id,
        root=args.root,
        fps=args.fps,
        dry_run=args.dry_run,
        community_id=args.community_id,
        revision=args.revision,
        no_preprocess=args.no_preprocess,
        path_save_to_disk=args.path_save_to_disk,
    )
    push_dataset_to_hub(**vars(args))
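
# push_dataset_to_hub(**vars(args)) works because argparse turns each
# --flag-name into an args.flag_name attribute, so the keys of vars(args)
# line up with the function's parameter names; illustratively:
#
#     vars(args) == {"data_dir": Path("data"), "dataset_id": "pusht",
#                    "raw_format": "pusht_zarr", "community_id": "lerobot", ...}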


if __name__ == "__main__":
    main()
@@ -11,6 +11,7 @@ Example usage:
`python tests/scripts/save_dataset_to_safetensors.py`
"""

import os
import shutil
from pathlib import Path

@@ -21,54 +22,56 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
    data_dir = Path(output_dir) / repo_id
    repo_dir = Path(output_dir) / repo_id

    if data_dir.exists():
        shutil.rmtree(data_dir)
    if repo_dir.exists():
        shutil.rmtree(repo_dir)

    data_dir.mkdir(parents=True, exist_ok=True)
    dataset = LeRobotDataset(repo_id=repo_id, root=data_dir)
    repo_dir.mkdir(parents=True, exist_ok=True)
    dataset = LeRobotDataset(
        repo_id=repo_id, root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
    )

    # save 2 first frames of first episode
    i = dataset.episode_data_index["from"][0].item()
    save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
    save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
    save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")

    # save 2 frames at the middle of first episode
    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
    save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
    save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
    save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")

    # save 2 last frames of first episode
    i = dataset.episode_data_index["to"][0].item()
    save_file(dataset[i - 2], data_dir / f"frame_{i-2}.safetensors")
    save_file(dataset[i - 1], data_dir / f"frame_{i-1}.safetensors")
    save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
    save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")

    # TODO(rcadene): Enable testing on second and last episode
    # We currently can't because our test dataset only contains the first episode

    # # save 2 first frames of second episode
    # i = dataset.episode_data_index["from"][1].item()
    # save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
    # save_file(dataset[i+1], data_dir / f"frame_{i+1}.safetensors")
    # save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
    # save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")

    # # save 2 last frames of second episode
    # i = dataset.episode_data_index["to"][1].item()
    # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
    # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
    # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
    # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")

    # # save 2 last frames of last episode
    # i = dataset.episode_data_index["to"][-1].item()
    # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
    # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
    # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
    # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")


if __name__ == "__main__":
    available_datasets = [
        "lerobot/pusht",
        "lerobot/xarm_push_medium",
        "lerobot/xarm_lift_medium",
        "lerobot/aloha_sim_insertion_human",
        "lerobot/umi_cup_in_the_wild",
        # "lerobot/umi_cup_in_the_wild",
    ]
    for dataset in available_datasets:
        save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
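
# Since the dataset is now loaded through LeRobotDataset with an optional
# DATA_DIR override, the script can be pointed at a local copy instead of the
# hub; a hypothetical shell invocation:
#
#     DATA_DIR=tests/data python tests/scripts/save_dataset_to_safetensors.py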
@@ -12,7 +12,9 @@ from safetensors.torch import load_file

import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import (
    LeRobotDataset,
)
from lerobot.common.datasets.utils import (
    compute_stats,
    flatten_dict,
@@ -22,8 +24,7 @@ from lerobot.common.datasets.utils import (
    unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config

from .utils import DEFAULT_CONFIG_PATH, DEVICE
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE


@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
@@ -238,35 +239,35 @@ def test_flatten_unflatten_dict():
    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"


def test_backward_compatibility():
    """This test's artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""

    all_repo_id = [
@pytest.mark.parametrize(
    "repo_id",
    [
        "lerobot/pusht",
        # TODO (azouitine): Add artifacts for the following datasets
        # "lerobot/aloha_sim_insertion_human",
        # "lerobot/xarm_push_medium",
        # "lerobot/umi_cup_in_the_wild",
    ]
    for repo_id in all_repo_id:
        "lerobot/aloha_sim_insertion_human",
        "lerobot/xarm_lift_medium",
    ],
)
def test_backward_compatibility(repo_id):
    """The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""

    dataset = LeRobotDataset(
        repo_id,
        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
    )

    data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
    test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id

    def load_and_compare(i):
        new_frame = dataset[i]  # noqa: B023
        old_frame = load_file(data_dir / f"frame_{i}.safetensors")  # noqa: B023
        old_frame = load_file(test_dir / f"frame_{i}.safetensors")  # noqa: B023

        new_keys = set(new_frame.keys())
        old_keys = set(old_frame.keys())
        assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"

        for key in new_frame:
            assert (
                new_frame[key] == old_frame[key]
            assert torch.isclose(
                new_frame[key], old_frame[key], rtol=1e-05, atol=1e-08
            ).all(), f"{key=} for index={i} does not contain the same value"

    # test 2 first frames of first episode
@@ -275,9 +276,7 @@ def test_backward_compatibility():
    load_and_compare(i + 1)

    # test 2 frames at the middle of first episode
    i = int(
        (dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2
    )
    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
    load_and_compare(i)
    load_and_compare(i + 1)
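
# For the tolerances used above, torch.isclose(a, b, rtol=1e-05, atol=1e-08)
# accepts |a - b| <= atol + rtol * |b|; a quick illustration:
import torch

a = torch.tensor(1.0)
b = torch.tensor(1.0 + 5e-6)
assert torch.isclose(a, b, rtol=1e-05, atol=1e-08)  # 5e-6 <= 1e-8 + 1e-5 * |b|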