From 37efcea3ebe74ff23751219d51510cb8fa984a80 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 18 Apr 2024 12:44:10 +0200 Subject: [PATCH] Add require_env wrapper --- tests/test_available.py | 5 ++--- tests/test_envs.py | 7 +++---- tests/test_policies.py | 3 ++- tests/utils.py | 30 ++++++++++++++++++++++++++++++ 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/tests/test_available.py b/tests/test_available.py index a29b1c8f..675e68d2 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -7,9 +7,11 @@ from lerobot.common.utils.import_utils import is_package_available from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy +from tests.utils import require_env @pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs) +@require_env def test_available_env_task(env_name: str, task_name: list): """ This test verifies that all environments listed in `lerobot/__init__.py` can @@ -17,9 +19,6 @@ def test_available_env_task(env_name: str, task_name: list): `available_tasks_per_env` are valid. """ package_name = f"gym_{env_name}" - if not is_package_available(package_name): - pytest.skip(f"gym-{env_name} not installed") - importlib.import_module(package_name) gym_handle = f"{package_name}/{task_name}" assert gym_handle in gym.envs.registry.keys(), gym_handle diff --git a/tests/test_envs.py b/tests/test_envs.py index ae0e2b51..75d86274 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -12,21 +12,19 @@ from lerobot.common.utils.utils import init_hydra_config import lerobot from lerobot.common.envs.utils import preprocess_observation -from .utils import DEVICE, DEFAULT_CONFIG_PATH +from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] @pytest.mark.parametrize("obs_type", OBS_TYPES) @pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs) +@require_env def test_env(env_name, env_task, obs_type): if env_name == "aloha" and obs_type == "state": pytest.skip("`state` observations not available for aloha") package_name = f"gym_{env_name}" - if not is_package_available(package_name): - pytest.skip(f"gym-{env_name} not installed") - importlib.import_module(package_name) env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type) check_env(env.unwrapped, skip_render_check=True) @@ -34,6 +32,7 @@ def test_env(env_name, env_task, obs_type): @pytest.mark.parametrize("env_name", lerobot.available_envs) +@require_env def test_factory(env_name): cfg = init_hydra_config( DEFAULT_CONFIG_PATH, diff --git a/tests/test_policies.py b/tests/test_policies.py index cca755c9..24b30a45 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -8,7 +8,7 @@ from lerobot.common.policies.policy_protocol import Policy from lerobot.common.envs.factory import make_env from lerobot.common.datasets.factory import make_dataset from lerobot.common.utils.utils import init_hydra_config -from .utils import DEVICE, DEFAULT_CONFIG_PATH +from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env # TODO(aliberts): refactor using lerobot/__init__.py variables @@ -24,6 +24,7 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]), ], ) +@require_env def test_policy(env_name, policy_name, extra_overrides): """ Tests: diff --git a/tests/utils.py b/tests/utils.py index 788e5bef..6709cde1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,36 @@ +import pytest import torch +from lerobot.common.utils.import_utils import is_package_available + # Pass this as the first argument to init_hydra_config. DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + +def require_env(func): + """ + Decorator that skips the test if the required environment package is not installed. + As it need 'env_name' in args, it also checks whether it is provided as an argument. + """ + from functools import wraps + + @wraps(func) + def wrapper(*args, **kwargs): + # Determine if 'env_name' is provided and extract its value + arg_names = func.__code__.co_varnames[:func.__code__.co_argcount] + if 'env_name' in arg_names: + # Get the index of 'env_name' and retrieve the value from args + index = arg_names.index('env_name') + env_name = args[index] if len(args) > index else kwargs.get('env_name') + else: + raise ValueError("Function does not have 'env_name' as an argument.") + + # Perform the package check + package_name = f"gym_{env_name}" + if not is_package_available(package_name): + pytest.skip(f"gym-{env_name} not installed") + + return func(*args, **kwargs) + + return wrapper \ No newline at end of file