From 293bdc7f677c26bae4a0c81301a2fc04940a42fe Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Fri, 1 Nov 2024 19:55:28 +0100
Subject: [PATCH] Simplify, add test content, add todo

---
 tests/fixtures/dataset_factories.py |  5 +-
 tests/test_datasets.py              | 72 ++++++++++++++---------
 2 files changed, 38 insertions(+), 39 deletions(-)

diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index d8d94b20..d98ae1e9 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -297,6 +297,7 @@ def lerobot_dataset_factory(
 ):
     def _create_lerobot_dataset(
         root: Path,
+        repo_id: str = DUMMY_REPO_ID,
         info_dict: dict = info,
         stats_dict: dict = stats,
         task_dicts: list[dict] = tasks,
@@ -322,7 +323,7 @@ def lerobot_dataset_factory(
         mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
         mock_snapshot_download_patch.side_effect = mock_snapshot_download
 
-        return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)
+        return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
 
     return _create_lerobot_dataset
 
@@ -341,6 +342,7 @@ def lerobot_dataset_from_episodes_factory(
         total_frames: int = 150,
         total_tasks: int = 1,
         multi_task: bool = False,
+        repo_id: str = DUMMY_REPO_ID,
         **kwargs,
     ):
         info_dict = info_factory(
@@ -356,6 +358,7 @@ def lerobot_dataset_from_episodes_factory(
         hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
         return lerobot_dataset_factory(
             root=root,
+            repo_id=repo_id,
             info_dict=info_dict,
             task_dicts=task_dicts,
             episode_dicts=episode_dicts,
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index c90ec93f..c46bb51a 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -41,26 +41,23 @@ from lerobot.common.datasets.utils import (
     unflatten_dict,
 )
 from lerobot.common.utils.utils import init_hydra_config, seeded_context
-from tests.fixtures.defaults import DEFAULT_FPS, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE
+from tests.fixtures.defaults import DUMMY_REPO_ID
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
 
 
-@pytest.fixture(scope="function")
-def dataset_create(tmp_path):
-    robot = make_robot("koch", mock=True)
-    return LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=tmp_path)
-
-
-@pytest.fixture(scope="function")
-def dataset_init(lerobot_dataset_factory, tmp_path):
-    return lerobot_dataset_factory(root=tmp_path)
-
-
-def test_same_attributes_defined(dataset_create, dataset_init):
+def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
     """
     Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that
     instantiated objects have the same sets of attributes defined.
""" + # Instantiate both ways + robot = make_robot("koch", mock=True) + root_create = tmp_path / "create" + dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) + + root_init = tmp_path / "init" + dataset_init = lerobot_dataset_factory(root=root_init) + # Access the '_hub_version' cached_property in both instances to force its creation _ = dataset_init._hub_version _ = dataset_create._hub_version @@ -68,35 +65,34 @@ def test_same_attributes_defined(dataset_create, dataset_init): init_attr = set(vars(dataset_init).keys()) create_attr = set(vars(dataset_create).keys()) - assert init_attr == create_attr, "Attribute sets do not match between __init__ and .create()" + assert init_attr == create_attr def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path): - total_episodes = 10 - total_frames = 400 - dataset = lerobot_dataset_from_episodes_factory( - root=tmp_path, total_episodes=total_episodes, total_frames=total_frames - ) + kwargs = { + "repo_id": DUMMY_REPO_ID, + "total_episodes": 10, + "total_frames": 400, + "episodes": [2, 5, 6], + } + dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs) - assert dataset.repo_id == DUMMY_REPO_ID - assert dataset.num_episodes == total_episodes - assert dataset.num_samples == total_frames - assert dataset.info["fps"] == DEFAULT_FPS - assert dataset.info["robot_type"] == DUMMY_ROBOT_TYPE + assert dataset.repo_id == kwargs["repo_id"] + assert dataset.total_episodes == kwargs["total_episodes"] + assert dataset.total_frames == kwargs["total_frames"] + assert dataset.episodes == kwargs["episodes"] + assert dataset.num_episodes == len(kwargs["episodes"]) + assert dataset.num_frames == len(dataset) -def test_dataset_length(dataset_init): - dataset = dataset_init - assert len(dataset) == 3 # Number of frames in the episode - - -def test_dataset_item(dataset_init): - dataset = dataset_init - item = dataset[0] - assert item["episode_index"] == 0 - assert item["frame_index"] == 0 - assert item["state"].tolist() == [1, 2, 3] - assert item["action"].tolist() == [0.1, 0.2] +# TODO(aliberts): +# - [ ] test various attributes & state from init and create +# - [ ] test init with episodes and check num_frames +# - [ ] test add_frame +# - [ ] test add_episode +# - [ ] test consolidate +# - [ ] test push_to_hub +# - [ ] test smaller methods @pytest.mark.skip("TODO after v2 migration / removing hydra") @@ -186,7 +182,7 @@ def test_multilerobotdataset_frames(): sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids] dataset = MultiLeRobotDataset(repo_ids) assert len(dataset) == sum(len(d) for d in sub_datasets) - assert dataset.num_samples == sum(d.num_samples for d in sub_datasets) + assert dataset.num_frames == sum(d.num_frames for d in sub_datasets) assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets) # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and @@ -266,6 +262,7 @@ def test_compute_stats_on_xarm(): # assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"]) +# TODO(aliberts): Move to more appropriate location def test_flatten_unflatten_dict(): d = { "obs": { @@ -301,7 +298,6 @@ def test_flatten_unflatten_dict(): # "lerobot/cmu_stretch", ], ) - # TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux def test_backward_compatibility(repo_id): """The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""