Fix generation of dataset test artifact (#306)

Alexander Soare 2024-07-05 11:02:26 +01:00 committed by GitHub
parent 74362ac453
commit 7bd5ab16d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 7 deletions

View File

@@ -222,6 +222,7 @@ def push_dataset_to_hub(
         # get the first episode
         num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
         test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
+        episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}
 
         test_hf_dataset = test_hf_dataset.with_format(None)
         test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
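
The added line keeps the saved metadata consistent with the one-episode test artifact: only the first episode's frames were already being selected, but episode_data_index still described every episode. A minimal sketch of the slicing, using made-up index tensors (the real values come from the converted dataset):

import torch

# Made-up index tensors standing in for the real episode_data_index:
# three episodes covering frames [0, 10), [10, 22), [22, 30).
episode_data_index = {
    "from": torch.tensor([0, 10, 22]),
    "to": torch.tensor([10, 22, 30]),
}

# Number of frames in the first episode only, as computed in push_dataset_to_hub.
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
assert int(num_items_first_ep) == 10

# The added line: truncate the index so it only describes that first episode.
episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}
assert torch.equal(episode_data_index["from"], torch.tensor([0]))
assert torch.equal(episode_data_index["to"], torch.tensor([10]))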

View File

@@ -251,17 +251,18 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
 @pytest.mark.parametrize(
-    "required_packages, raw_format, repo_id",
+    "required_packages, raw_format, repo_id, make_test_data",
     [
-        (["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
-        (None, "xarm_pkl", "lerobot/xarm_lift_medium"),
-        (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
-        (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
-        (None, "dora_parquet", "cadene/wrist_gripper"),
+        (["gym_pusht"], "pusht_zarr", "lerobot/pusht", False),
+        (["gym_pusht"], "pusht_zarr", "lerobot/pusht", True),
+        (None, "xarm_pkl", "lerobot/xarm_lift_medium", False),
+        (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False),
+        (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False),
+        (None, "dora_parquet", "cadene/wrist_gripper", False),
     ],
 )
 @require_package_arg
-def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
+def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
     num_episodes = 3
     tmpdir = Path(tmpdir)
@@ -278,6 +279,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
         local_dir=local_dir,
         force_override=False,
         cache_dir=tmpdir / "cache",
+        tests_data_dir=tmpdir / "tests/data" if make_test_data else None,
     )
 
     # minimal generic tests on the local directory containing LeRobotDataset
@@ -299,6 +301,20 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
     for cam_key in lerobot_dataset.camera_keys:
         assert cam_key in item
 
+    if make_test_data:
+        # Check that only the first episode is selected.
+        test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data")
+        num_frames = sum(
+            i == lerobot_dataset.hf_dataset["episode_index"][0]
+            for i in lerobot_dataset.hf_dataset["episode_index"]
+        ).item()
+        assert (
+            test_dataset.hf_dataset["episode_index"]
+            == lerobot_dataset.hf_dataset["episode_index"][:num_frames]
+        )
+        for k in ["from", "to"]:
+            assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
+
 
 @pytest.mark.parametrize(
     "raw_format, repo_id",