diff --git a/tests/test_datasets.py b/tests/test_datasets.py index afcbb3ba..6d0055a2 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -118,8 +118,8 @@ def test_multilerobotdataset_frames(): """Check that all dataset frames are incorporated.""" # Note: use the image variants of the dataset to make the test approx 3x faster. repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"] - sub_datasets = [LeRobotDataset(repo_id, root="tests/data") for repo_id in repo_ids] - dataset = MultiLeRobotDataset(repo_ids, root="tests/data") + sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids] + dataset = MultiLeRobotDataset(repo_ids) assert len(dataset) == sum(len(d) for d in sub_datasets) assert dataset.num_samples == sum(d.num_samples for d in sub_datasets) assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)