From 6082a7bc7303d86bd6e5ed49def1e039fbfaf988 Mon Sep 17 00:00:00 2001 From: Cadene Date: Wed, 10 Apr 2024 13:06:48 +0000 Subject: [PATCH] Enable test_available.py --- lerobot/__init__.py | 21 +++++------ tests/test_available.py | 83 ++++++++++++++++++----------------------- 2 files changed, 45 insertions(+), 59 deletions(-) diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 4673aab0..8ab95df8 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -12,14 +12,11 @@ Example: print(lerobot.available_policies) ``` -Note: - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class +When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps: +- Set the required class attributes: `available_datasets`. +- Set the required class attributes: `name`. +- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) +- Update variables in `tests/test_available.py` by importing your new class """ from lerobot.__version__ import __version__ # noqa: F401 @@ -32,11 +29,11 @@ available_envs = [ available_tasks_per_env = { "aloha": [ - "sim_insertion", - "sim_transfer_cube", + "AlohaInsertion-v0", + "AlohaTransferCube-v0", ], - "pusht": ["pusht"], - "xarm": ["lift"], + "pusht": ["PushT-v0"], + "xarm": ["XarmLift-v0"], } available_datasets_per_env = { diff --git a/tests/test_available.py b/tests/test_available.py index 8df2c945..be74a42a 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -1,64 +1,53 @@ """ This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be sucessfully -imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) corresponds. +imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) are valid. -Note: - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class +When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps: +- Set the required class attributes: `available_datasets`. +- Set the required class attributes: `name`. +- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) +- Update variables in `tests/test_available.py` by importing your new class """ +import importlib import pytest import lerobot +import gymnasium as gym -# from lerobot.common.envs.aloha.env import AlohaEnv -# from gym_pusht.envs import PushtEnv -# from gym_xarm.envs import SimxarmEnv +from lerobot.common.datasets.xarm import XarmDataset +from lerobot.common.datasets.aloha import AlohaDataset +from lerobot.common.datasets.pusht import PushtDataset -# from lerobot.common.datasets.xarm import SimxarmDataset -# from lerobot.common.datasets.aloha import AlohaDataset -# from lerobot.common.datasets.pusht import PushtDataset - -# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy -# from lerobot.common.policies.diffusion.policy import DiffusionPolicy -# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy +from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy +from lerobot.common.policies.diffusion.policy import DiffusionPolicy +from lerobot.common.policies.tdmpc.policy import TDMPCPolicy -# def test_available(): -# pol_classes = [ -# ActionChunkingTransformerPolicy, -# DiffusionPolicy, -# TDMPCPolicy, -# ] +def test_available(): + policy_classes = [ + ActionChunkingTransformerPolicy, + DiffusionPolicy, + TDMPCPolicy, + ] -# env_classes = [ -# AlohaEnv, -# PushtEnv, -# SimxarmEnv, -# ] - -# dat_classes = [ -# AlohaDataset, -# PushtDataset, -# SimxarmDataset, -# ] + dataset_class_per_env = { + "aloha": AlohaDataset, + "pusht": PushtDataset, + "xarm": XarmDataset, + } -# policies = [pol_cls.name for pol_cls in pol_classes] -# assert set(policies) == set(lerobot.available_policies) + policies = [pol_cls.name for pol_cls in policy_classes] + assert set(policies) == set(lerobot.available_policies), policies -# envs = [env_cls.name for env_cls in env_classes] -# assert set(envs) == set(lerobot.available_envs) + for env_name in lerobot.available_envs: + for task_name in lerobot.available_tasks_per_env[env_name]: + package_name = f"gym_{env_name}" + importlib.import_module(package_name) + gym_handle = f"{package_name}/{task_name}" + assert gym_handle in gym.envs.registry.keys(), gym_handle -# tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes} -# for env in envs: -# assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env]) - -# datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)} -# for env in envs: -# assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env]) + dataset_class = dataset_class_per_env[env_name] + available_datasets = lerobot.available_datasets_per_env[env_name] + assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"