Revert to test available_datasets
This commit is contained in:
parent
d407ce21aa
commit
04340df89a
|
@ -4,6 +4,9 @@ import gymnasium as gym
|
|||
import pytest
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
@ -24,6 +27,25 @@ def test_available_env_task(env_name: str, task_name: list):
|
|||
assert gym_handle in gym.envs.registry, gym_handle
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, dataset_class",
|
||||
[
|
||||
("aloha", AlohaDataset),
|
||||
("pusht", PushtDataset),
|
||||
("xarm", XarmDataset),
|
||||
],
|
||||
)
|
||||
def test_available_datasets(env_name, dataset_class):
|
||||
"""
|
||||
This test verifies that the class attribute `available_datasets` for all
|
||||
dataset classes is consistent with those listed in `lerobot/__init__.py`.
|
||||
"""
|
||||
available_env_datasets = lerobot.available_datasets[env_name]
|
||||
assert set(available_env_datasets) == set(
|
||||
dataset_class.available_datasets
|
||||
), f"{env_name=} {available_env_datasets=}"
|
||||
|
||||
|
||||
def test_available_policies():
|
||||
"""
|
||||
This test verifies that the class attribute `name` for all policies is
|
||||
|
|
Loading…
Reference in New Issue