Enable test_available.py

This commit is contained in:
Cadene 2024-04-10 13:06:48 +00:00
parent 7c8eb7ff19
commit 6082a7bc73
2 changed files with 45 additions and 59 deletions

View File

@ -12,14 +12,11 @@ Example:
print(lerobot.available_policies) print(lerobot.available_policies)
``` ```
Note: When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - Set the required class attributes: `available_datasets`.
1. set the required class attributes: - Set the required class attributes: `name`.
- for classes inheriting from `AbstractDataset`: `available_datasets` - Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - Update variables in `tests/test_available.py` by importing your new class
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
""" """
from lerobot.__version__ import __version__ # noqa: F401 from lerobot.__version__ import __version__ # noqa: F401
@ -32,11 +29,11 @@ available_envs = [
available_tasks_per_env = { available_tasks_per_env = {
"aloha": [ "aloha": [
"sim_insertion", "AlohaInsertion-v0",
"sim_transfer_cube", "AlohaTransferCube-v0",
], ],
"pusht": ["pusht"], "pusht": ["PushT-v0"],
"xarm": ["lift"], "xarm": ["XarmLift-v0"],
} }
available_datasets_per_env = { available_datasets_per_env = {

View File

@ -1,64 +1,53 @@
""" """
This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be sucessfully This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be sucessfully
imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) corresponds. imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) are valid.
Note: When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - Set the required class attributes: `available_datasets`.
1. set the required class attributes: - Set the required class attributes: `name`.
- for classes inheriting from `AbstractDataset`: `available_datasets` - Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - Update variables in `tests/test_available.py` by importing your new class
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
""" """
import importlib
import pytest import pytest
import lerobot import lerobot
import gymnasium as gym
# from lerobot.common.envs.aloha.env import AlohaEnv from lerobot.common.datasets.xarm import XarmDataset
# from gym_pusht.envs import PushtEnv from lerobot.common.datasets.aloha import AlohaDataset
# from gym_xarm.envs import SimxarmEnv from lerobot.common.datasets.pusht import PushtDataset
# from lerobot.common.datasets.xarm import SimxarmDataset from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
# from lerobot.common.datasets.aloha import AlohaDataset from lerobot.common.policies.diffusion.policy import DiffusionPolicy
# from lerobot.common.datasets.pusht import PushtDataset from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
# from lerobot.common.policies.diffusion.policy import DiffusionPolicy
# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
# def test_available(): def test_available():
# pol_classes = [ policy_classes = [
# ActionChunkingTransformerPolicy, ActionChunkingTransformerPolicy,
# DiffusionPolicy, DiffusionPolicy,
# TDMPCPolicy, TDMPCPolicy,
# ] ]
# env_classes = [ dataset_class_per_env = {
# AlohaEnv, "aloha": AlohaDataset,
# PushtEnv, "pusht": PushtDataset,
# SimxarmEnv, "xarm": XarmDataset,
# ] }
# dat_classes = [
# AlohaDataset,
# PushtDataset,
# SimxarmDataset,
# ]
# policies = [pol_cls.name for pol_cls in pol_classes] policies = [pol_cls.name for pol_cls in policy_classes]
# assert set(policies) == set(lerobot.available_policies) assert set(policies) == set(lerobot.available_policies), policies
# envs = [env_cls.name for env_cls in env_classes] for env_name in lerobot.available_envs:
# assert set(envs) == set(lerobot.available_envs) for task_name in lerobot.available_tasks_per_env[env_name]:
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry.keys(), gym_handle
# tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes} dataset_class = dataset_class_per_env[env_name]
# for env in envs: available_datasets = lerobot.available_datasets_per_env[env_name]
# assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env]) assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"
# datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
# for env in envs:
# assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])