Add require_env wrapper

This commit is contained in:
Simon Alibert 2024-04-18 12:44:10 +02:00
parent d167e5e3c5
commit 37efcea3eb
4 changed files with 37 additions and 8 deletions

View File

@ -7,9 +7,11 @@ from lerobot.common.utils.import_utils import is_package_available
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
from tests.utils import require_env
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
@require_env
def test_available_env_task(env_name: str, task_name: list):
"""
This test verifies that all environments listed in `lerobot/__init__.py` can
@ -17,9 +19,6 @@ def test_available_env_task(env_name: str, task_name: list):
`available_tasks_per_env` are valid.
"""
package_name = f"gym_{env_name}"
if not is_package_available(package_name):
pytest.skip(f"gym-{env_name} not installed")
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry.keys(), gym_handle

View File

@ -12,21 +12,19 @@ from lerobot.common.utils.utils import init_hydra_config
import lerobot
from lerobot.common.envs.utils import preprocess_observation
from .utils import DEVICE, DEFAULT_CONFIG_PATH
from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
@pytest.mark.parametrize("obs_type", OBS_TYPES)
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
@require_env
def test_env(env_name, env_task, obs_type):
if env_name == "aloha" and obs_type == "state":
pytest.skip("`state` observations not available for aloha")
package_name = f"gym_{env_name}"
if not is_package_available(package_name):
pytest.skip(f"gym-{env_name} not installed")
importlib.import_module(package_name)
env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type)
check_env(env.unwrapped, skip_render_check=True)
@ -34,6 +32,7 @@ def test_env(env_name, env_task, obs_type):
@pytest.mark.parametrize("env_name", lerobot.available_envs)
@require_env
def test_factory(env_name):
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,

View File

@ -8,7 +8,7 @@ from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env
# TODO(aliberts): refactor using lerobot/__init__.py variables
@ -24,6 +24,7 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
],
)
@require_env
def test_policy(env_name, policy_name, extra_overrides):
"""
Tests:

View File

@ -1,6 +1,36 @@
import pytest
import torch
from lerobot.common.utils.import_utils import is_package_available
# Pass this as the first argument to init_hydra_config.
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def require_env(func):
"""
Decorator that skips the test if the required environment package is not installed.
As it need 'env_name' in args, it also checks whether it is provided as an argument.
"""
from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
# Determine if 'env_name' is provided and extract its value
arg_names = func.__code__.co_varnames[:func.__code__.co_argcount]
if 'env_name' in arg_names:
# Get the index of 'env_name' and retrieve the value from args
index = arg_names.index('env_name')
env_name = args[index] if len(args) > index else kwargs.get('env_name')
else:
raise ValueError("Function does not have 'env_name' as an argument.")
# Perform the package check
package_name = f"gym_{env_name}"
if not is_package_available(package_name):
pytest.skip(f"gym-{env_name} not installed")
return func(*args, **kwargs)
return wrapper