lerobot/tests/test_available.py

65 lines
2.4 KiB
Python
Raw Normal View History

"""
This test verifies that all environments, datasets, and policies listed in `lerobot/__init__.py` can be successfully
imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) correspond.
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_env`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
import pytest
import lerobot
2024-04-06 07:27:12 +08:00
# from lerobot.common.envs.aloha.env import AlohaEnv
# from gym_pusht.envs import PushtEnv
# from gym_xarm.envs import SimxarmEnv
2024-04-06 07:27:12 +08:00
# from lerobot.common.datasets.simxarm import SimxarmDataset
# from lerobot.common.datasets.aloha import AlohaDataset
# from lerobot.common.datasets.pusht import PushtDataset
2024-04-06 07:27:12 +08:00
# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
# from lerobot.common.policies.diffusion.policy import DiffusionPolicy
# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
2024-04-06 07:27:12 +08:00
# def test_available():
# pol_classes = [
# ActionChunkingTransformerPolicy,
# DiffusionPolicy,
# TDMPCPolicy,
# ]
2024-04-06 07:27:12 +08:00
# env_classes = [
# AlohaEnv,
# PushtEnv,
# SimxarmEnv,
# ]
2024-04-06 07:27:12 +08:00
# dat_classes = [
# AlohaDataset,
# PushtDataset,
# SimxarmDataset,
# ]
2024-04-06 07:27:12 +08:00
# policies = [pol_cls.name for pol_cls in pol_classes]
# assert set(policies) == set(lerobot.available_policies)
2024-04-06 07:27:12 +08:00
# envs = [env_cls.name for env_cls in env_classes]
# assert set(envs) == set(lerobot.available_envs)
2024-04-06 07:27:12 +08:00
# tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
# for env in envs:
# assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
2024-04-06 07:27:12 +08:00
# datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
# for env in envs:
# assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])