Revert to test available_datasets
This commit is contained in:
parent
d407ce21aa
commit
04340df89a
|
@ -4,6 +4,9 @@ import gymnasium as gym
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
|
from lerobot.common.datasets.aloha import AlohaDataset
|
||||||
|
from lerobot.common.datasets.pusht import PushtDataset
|
||||||
|
from lerobot.common.datasets.xarm import XarmDataset
|
||||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||||
|
@ -24,6 +27,25 @@ def test_available_env_task(env_name: str, task_name: list):
|
||||||
assert gym_handle in gym.envs.registry, gym_handle
|
assert gym_handle in gym.envs.registry, gym_handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env_name, dataset_class",
|
||||||
|
[
|
||||||
|
("aloha", AlohaDataset),
|
||||||
|
("pusht", PushtDataset),
|
||||||
|
("xarm", XarmDataset),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_available_datasets(env_name, dataset_class):
|
||||||
|
"""
|
||||||
|
This test verifies that the class attribute `available_datasets` for all
|
||||||
|
dataset classes is consistent with those listed in `lerobot/__init__.py`.
|
||||||
|
"""
|
||||||
|
available_env_datasets = lerobot.available_datasets[env_name]
|
||||||
|
assert set(available_env_datasets) == set(
|
||||||
|
dataset_class.available_datasets
|
||||||
|
), f"{env_name=} {available_env_datasets=}"
|
||||||
|
|
||||||
|
|
||||||
def test_available_policies():
|
def test_available_policies():
|
||||||
"""
|
"""
|
||||||
This test verifies that the class attribute `name` for all policies is
|
This test verifies that the class attribute `name` for all policies is
|
||||||
|
|
Loading…
Reference in New Issue