From 00cefb44bc9b01cd1275bbf4d1168fc68d3049b6 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Thu, 30 May 2024 17:17:37 +0100
Subject: [PATCH] ready for review

---
 tests/test_datasets.py | 37 +++++++++++++++++++++++++++++++------
 1 file changed, 31 insertions(+), 6 deletions(-)

diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e01fc52c..bd63e568 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -16,6 +16,7 @@
 import json
 import logging
 from copy import deepcopy
+from itertools import chain
 from pathlib import Path
 
 import einops
@@ -31,7 +32,7 @@ from lerobot.common.datasets.compute_stats import (
     get_stats_einops_patterns,
 )
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
 from lerobot.common.datasets.utils import (
     flatten_dict,
     hf_transform_to_torch,
@@ -42,11 +43,7 @@ from lerobot.common.utils.utils import init_hydra_config, seeded_context
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
 
 
-@pytest.mark.parametrize(
-    "env_name, repo_id, policy_name",
-    lerobot.env_dataset_policy_triplets
-    + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
-)
+@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
 def test_factory(env_name, repo_id, policy_name):
     """
     Tests that:
@@ -113,6 +110,30 @@ def test_factory(env_name, repo_id, policy_name):
             assert key in item, f"{key}"
 
 
+def test_multilerobotdataset_frames():
+    """Check that all dataset frames are incorporated."""
+    # Note: use the image variants of the dataset to make the test approx 3x faster.
+    repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
+    sub_datasets = [LeRobotDataset(repo_id, root="tests/data") for repo_id in repo_ids]
+    dataset = MultiLeRobotDataset(repo_ids, root="tests/data")
+    assert len(dataset) == sum(len(d) for d in sub_datasets)
+    assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
+    assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
+    # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
+    # check they match.
+    expected_dataset_indices = []
+    for i, sub_dataset in enumerate(sub_datasets):
+        expected_dataset_indices.extend([i] * len(sub_dataset))
+    for expected_dataset_index, sub_dataset_item, dataset_item in zip(
+        expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
+    ):
+        dataset_index = dataset_item.pop("dataset_index")
+        assert dataset_index == expected_dataset_index
+        assert sub_dataset_item.keys() == dataset_item.keys()
+        for k in sub_dataset_item:
+            assert torch.equal(sub_dataset_item[k], dataset_item[k])
+
+
 def test_compute_stats_on_xarm():
     """Check that the statistics are computed correctly according to the stats_patterns property.
 
@@ -351,3 +372,7 @@ def test_aggregate_stats():
         for agg_fn in ["mean", "min", "max"]:
             assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
         assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
+
+
+if __name__ == "__main__":
+    test_multilerobotdataset_frames()
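
Note for reviewers (not part of the patch): the sketch below is a minimal illustration of what the new
test exercises. It assumes MultiLeRobotDataset behaves as a plain torch-style dataset whose items are
dicts of tensors carrying an extra "dataset_index" entry, which is what the assertions in
test_multilerobotdataset_frames imply; the repo IDs and root="tests/data" are taken verbatim from the
test fixtures, and everything else is illustrative rather than part of the lerobot API surface.

    import torch

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset

    # Repo IDs and root copied from the new test; adjust to your local fixtures.
    repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
    root = "tests/data"

    # A MultiLeRobotDataset aggregates several LeRobotDatasets into one dataset.
    multi_dataset = MultiLeRobotDataset(repo_ids, root=root)
    sub_datasets = [LeRobotDataset(repo_id, root=root) for repo_id in repo_ids]

    # Frame counts add up across the sub-datasets (the same invariant the test asserts).
    assert len(multi_dataset) == sum(len(d) for d in sub_datasets)

    # Each item carries a "dataset_index" identifying its source dataset; tally how many
    # frames each sub-dataset contributes to the aggregate.
    counts = torch.zeros(len(repo_ids), dtype=torch.long)
    for item in multi_dataset:
        counts[int(item["dataset_index"])] += 1
    print({repo_id: int(n) for repo_id, n in zip(repo_ids, counts)})

Under these assumptions the printed counts should equal the length of each sub-dataset, which is exactly
the frame-accounting property the new test locks in.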