Fix generation of dataset test artifact (#306)
This commit is contained in:
parent 74362ac453
commit 7bd5ab16d1
|
@ -222,6 +222,7 @@ def push_dataset_to_hub(
|
|||
# get the first episode
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
||||
episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}
|
||||
|
||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||
test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
|
||||
|
|
|
@ -251,17 +251,18 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"required_packages, raw_format, repo_id",
|
||||
"required_packages, raw_format, repo_id, make_test_data",
|
||||
[
|
||||
(["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
|
||||
(None, "xarm_pkl", "lerobot/xarm_lift_medium"),
|
||||
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
|
||||
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
|
||||
(None, "dora_parquet", "cadene/wrist_gripper"),
|
||||
(["gym_pusht"], "pusht_zarr", "lerobot/pusht", False),
|
||||
(["gym_pusht"], "pusht_zarr", "lerobot/pusht", True),
|
||||
(None, "xarm_pkl", "lerobot/xarm_lift_medium", False),
|
||||
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False),
|
||||
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False),
|
||||
(None, "dora_parquet", "cadene/wrist_gripper", False),
|
||||
],
|
||||
)
|
||||
@require_package_arg
|
||||
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
|
||||
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
|
||||
num_episodes = 3
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
|
@ -278,6 +279,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
|
|||
local_dir=local_dir,
|
||||
force_override=False,
|
||||
cache_dir=tmpdir / "cache",
|
||||
tests_data_dir=tmpdir / "tests/data" if make_test_data else None,
|
||||
)
|
||||
|
||||
# minimal generic tests on the local directory containing LeRobotDataset
|
||||
|
@ -299,6 +301,20 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
|
|||
for cam_key in lerobot_dataset.camera_keys:
|
||||
assert cam_key in item
|
||||
|
||||
if make_test_data:
|
||||
# Check that only the first episode is selected.
|
||||
test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data")
|
||||
num_frames = sum(
|
||||
i == lerobot_dataset.hf_dataset["episode_index"][0]
|
||||
for i in lerobot_dataset.hf_dataset["episode_index"]
|
||||
).item()
|
||||
assert (
|
||||
test_dataset.hf_dataset["episode_index"]
|
||||
== lerobot_dataset.hf_dataset["episode_index"][:num_frames]
|
||||
)
|
||||
for k in ["from", "to"]:
|
||||
assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw_format, repo_id",
|
||||
|
|
Loading…
Reference in New Issue