"""
|
|
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
|
|
Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets,
|
|
we skip them for now in our CI.
|
|
|
|
Example to run backward compatiblity tests locally:
|
|
```
|
|
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
|
|
```
|
|
"""

from pathlib import Path

import numpy as np
import pytest
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg

def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
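    # Build a tiny zarr store that mirrors the layout fed to the "pusht_zarr" raw format
    # in the tests below. `meta/episode_ends` holds cumulative end indices: [1, 3, 4]
    # splits the 4 mock frames into 3 episodes.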
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/img",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
    zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
    zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])

    store.close()


def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
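    # Same idea as the PushT mock: a minimal zarr store with the camera, pose and gripper
    # arrays used by the "umi_zarr" raw format test case below.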
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/camera0_rgb",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_end_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_start_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/robot0_eef_rot_axis_angle",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_gripper_width",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])

    store.close()


def _mock_download_raw_xarm(raw_dir, num_frames=4):
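    # The "xarm_pkl" test case consumes a single pickled buffer ("buffer.pkl"); a small
    # dict of random arrays with the expected keys is enough here.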
    import pickle

    dataset_dict = {
        "observations": {
            "rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
            "state": np.random.randn(num_frames, 4),
        },
        "actions": np.random.randn(num_frames, 3),
        "rewards": np.random.randn(num_frames),
        "masks": np.random.randn(num_frames),
        "dones": np.array([False, True, True, True]),
    }

    raw_dir.mkdir(parents=True, exist_ok=True)
    pkl_path = raw_dir / "buffer.pkl"
    with open(pkl_path, "wb") as f:
        pickle.dump(dataset_dict, f)


def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
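    # The "aloha_hdf5" test case reads one HDF5 file per episode with qpos/qvel and a top
    # camera stream; each mock episode gets num_frames // num_episodes frames.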
    import h5py

    for ep_idx in range(num_episodes):
        raw_dir.mkdir(parents=True, exist_ok=True)
        path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
        with h5py.File(str(path_h5), "w") as f:
            f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset(
                "observations/images/top",
                data=np.random.randint(
                    0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
                ),
            )


def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
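    # The "dora_parquet" test case reads one parquet file per key ("observation.*",
    # "action", "episode_index") plus one mp4 per camera and episode, all generated below.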
    from datetime import datetime, timedelta, timezone

    import pandas

    def write_parquet(key, timestamps, values):
        data = {
            "timestamp_utc": timestamps,
            key: values,
        }
        df = pandas.DataFrame(data)
        raw_dir.mkdir(parents=True, exist_ok=True)
        df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")

    episode_indices = [None, None, -1, None, None, -1, None, None, -1]
    episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]

    cam_key = "observation.images.cam_high"
    timestamps = []
    actions = []
    states = []
    frames = []
    # every third entry is a buffer frame associated with episode_index=-1
    for i, frame_idx in enumerate(frame_indices):
        t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
        action = np.random.randn(21).tolist()
        state = np.random.randn(21).tolist()
        ep_idx = episode_indices_mapping[i]
        frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
        timestamps.append(t_utc)
        actions.append(action)
        states.append(state)
        frames.append(frame)

    write_parquet(cam_key, timestamps, frames)
    write_parquet("observation.state", timestamps, states)
    write_parquet("action", timestamps, actions)
    write_parquet("episode_index", timestamps, episode_indices)

    # write a fake mp4 file for each episode
    for ep_idx in range(num_episodes):
        imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)

        tmp_imgs_dir = raw_dir / "tmp_images"
        save_images_concurrently(imgs_array, tmp_imgs_dir)

        fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
        video_path = raw_dir / "videos" / fname
        encode_video_frames(tmp_imgs_dir, video_path, fps)


def _mock_download_raw(raw_dir, repo_id):
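    # Pick the mock that matches the dataset family referenced by repo_id.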
    if "wrist_gripper" in repo_id:
        _mock_download_raw_dora(raw_dir)
    elif "aloha" in repo_id:
        _mock_download_raw_aloha(raw_dir)
    elif "pusht" in repo_id:
        _mock_download_raw_pusht(raw_dir)
    elif "xarm" in repo_id:
        _mock_download_raw_xarm(raw_dir)
    elif "umi" in repo_id:
        _mock_download_raw_umi(raw_dir)
    else:
        raise ValueError(repo_id)


def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
    with pytest.raises(ValueError):
        push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")


def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
    tmpdir = Path(tmpdir)
    out_dir = tmpdir / "out"
    raw_dir = tmpdir / "raw"
    # mkdir to skip download
    raw_dir.mkdir(parents=True, exist_ok=True)
    with pytest.raises(ValueError):
        push_dataset_to_hub(
            raw_dir=raw_dir,
            raw_format="some_format",
            repo_id="user/dataset",
            local_dir=out_dir,
            force_override=False,
        )


@pytest.mark.parametrize(
    "required_packages, raw_format, repo_id, make_test_data",
    [
        (["gym_pusht"], "pusht_zarr", "lerobot/pusht", False),
        (["gym_pusht"], "pusht_zarr", "lerobot/pusht", True),
        (None, "xarm_pkl", "lerobot/xarm_lift_medium", False),
        (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False),
        (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False),
        (None, "dora_parquet", "cadene/wrist_gripper", False),
    ],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
    num_episodes = 3
    tmpdir = Path(tmpdir)

    raw_dir = tmpdir / f"{repo_id}_raw"
    _mock_download_raw(raw_dir, repo_id)

    local_dir = tmpdir / repo_id

    lerobot_dataset = push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
        tests_data_dir=tmpdir / "tests/data" if make_test_data else None,
    )

    # minimal generic tests on the local directory containing LeRobotDataset
    assert (local_dir / "meta_data" / "info.json").exists()
    assert (local_dir / "meta_data" / "stats.safetensors").exists()
    assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
    for i in range(num_episodes):
        for cam_key in lerobot_dataset.camera_keys:
            assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
    assert (local_dir / "train" / "dataset_info.json").exists()
    assert (local_dir / "train" / "state.json").exists()
    assert len(list((local_dir / "train").glob("*.arrow"))) > 0

    # minimal generic tests on the item
    item = lerobot_dataset[0]
    assert "index" in item
    assert "episode_index" in item
    assert "timestamp" in item
    for cam_key in lerobot_dataset.camera_keys:
        assert cam_key in item

    if make_test_data:
        # Check that only the first episode is selected.
        test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data")
        num_frames = sum(
            i == lerobot_dataset.hf_dataset["episode_index"][0]
            for i in lerobot_dataset.hf_dataset["episode_index"]
        ).item()
        assert (
            test_dataset.hf_dataset["episode_index"]
            == lerobot_dataset.hf_dataset["episode_index"][:num_frames]
        )
        for k in ["from", "to"]:
            assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])


@pytest.mark.parametrize(
    "raw_format, repo_id",
    [
        # TODO(rcadene): add raw dataset test artifacts
        ("pusht_zarr", "lerobot/pusht"),
        ("xarm_pkl", "lerobot/xarm_lift_medium"),
        ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
        ("umi_zarr", "lerobot/umi_cup_in_the_wild"),
        ("dora_parquet", "cadene/wrist_gripper"),
    ],
)
@pytest.mark.skip(
    "Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
    _, dataset_id = repo_id.split("/")

    tmpdir = Path(tmpdir)
    raw_dir = tmpdir / f"{dataset_id}_raw"
    local_dir = tmpdir / repo_id

    push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
        episodes=[0],
    )

    ds_actual = LeRobotDataset(repo_id, root=tmpdir)
    ds_reference = LeRobotDataset(repo_id)

    assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)

    def check_same_items(item1, item2):
        assert item1.keys() == item2.keys(), "Keys mismatch"

        for key in item1:
            if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
                assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
            else:
                assert item1[key] == item2[key], f"Mismatch found in key: {key}"

    for i in range(len(ds_reference.hf_dataset)):
        item_reference = ds_reference.hf_dataset[i]
        item_actual = ds_actual.hf_dataset[i]
        check_same_items(item_reference, item_actual)