From 04340df89ab799444970f02e52e8a8f5b4c7b209 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 18 Apr 2024 14:17:44 +0200 Subject: [PATCH] Revert to test available_datasets --- tests/test_available.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_available.py b/tests/test_available.py index ed3b22bf..4328ec69 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -4,6 +4,9 @@ import gymnasium as gym import pytest import lerobot +from lerobot.common.datasets.aloha import AlohaDataset +from lerobot.common.datasets.pusht import PushtDataset +from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy @@ -24,6 +27,25 @@ def test_available_env_task(env_name: str, task_name: list): assert gym_handle in gym.envs.registry, gym_handle +@pytest.mark.parametrize( + "env_name, dataset_class", + [ + ("aloha", AlohaDataset), + ("pusht", PushtDataset), + ("xarm", XarmDataset), + ], +) +def test_available_datasets(env_name, dataset_class): + """ + This test verifies that the class attribute `available_datasets` for all + dataset classes is consistent with those listed in `lerobot/__init__.py`. + """ + available_env_datasets = lerobot.available_datasets[env_name] + assert set(available_env_datasets) == set( + dataset_class.available_datasets + ), f"{env_name=} {available_env_datasets=}" + + def test_available_policies(): """ This test verifies that the class attribute `name` for all policies is