From 7bd5ab16d1edfb00b7ed71b4cf8ee6a815b64bc1 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 5 Jul 2024 11:02:26 +0100 Subject: [PATCH] Fix generation of dataset test artifact (#306) --- lerobot/scripts/push_dataset_to_hub.py | 1 + tests/test_push_dataset_to_hub.py | 30 ++++++++++++++++++++------ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 92a0cc45..e471d5bd 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -222,6 +222,7 @@ def push_dataset_to_hub( # get the first episode num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0] test_hf_dataset = hf_dataset.select(range(num_items_first_ep)) + episode_data_index = {k: v[:1] for k, v in episode_data_index.items()} test_hf_dataset = test_hf_dataset.with_format(None) test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train")) diff --git a/tests/test_push_dataset_to_hub.py b/tests/test_push_dataset_to_hub.py index 7ddbe7aa..9f3bff93 100644 --- a/tests/test_push_dataset_to_hub.py +++ b/tests/test_push_dataset_to_hub.py @@ -251,17 +251,18 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir): @pytest.mark.parametrize( - "required_packages, raw_format, repo_id", + "required_packages, raw_format, repo_id, make_test_data", [ - (["gym-pusht"], "pusht_zarr", "lerobot/pusht"), - (None, "xarm_pkl", "lerobot/xarm_lift_medium"), - (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"), - (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"), - (None, "dora_parquet", "cadene/wrist_gripper"), + (["gym_pusht"], "pusht_zarr", "lerobot/pusht", False), + (["gym_pusht"], "pusht_zarr", "lerobot/pusht", True), + (None, "xarm_pkl", "lerobot/xarm_lift_medium", False), + (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False), + (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False), + (None, "dora_parquet", "cadene/wrist_gripper", False), ], ) @require_package_arg -def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id): +def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data): num_episodes = 3 tmpdir = Path(tmpdir) @@ -278,6 +279,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_ local_dir=local_dir, force_override=False, cache_dir=tmpdir / "cache", + tests_data_dir=tmpdir / "tests/data" if make_test_data else None, ) # minimal generic tests on the local directory containing LeRobotDataset @@ -299,6 +301,20 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_ for cam_key in lerobot_dataset.camera_keys: assert cam_key in item + if make_test_data: + # Check that only the first episode is selected. + test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data") + num_frames = sum( + i == lerobot_dataset.hf_dataset["episode_index"][0] + for i in lerobot_dataset.hf_dataset["episode_index"] + ).item() + assert ( + test_dataset.hf_dataset["episode_index"] + == lerobot_dataset.hf_dataset["episode_index"][:num_frames] + ) + for k in ["from", "to"]: + assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1]) + @pytest.mark.parametrize( "raw_format, repo_id",