Add require_env wrapper
This commit is contained in:
parent
d167e5e3c5
commit
37efcea3eb
|
@ -7,9 +7,11 @@ from lerobot.common.utils.import_utils import is_package_available
|
|||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
from tests.utils import require_env
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
|
||||
@require_env
|
||||
def test_available_env_task(env_name: str, task_name: list):
|
||||
"""
|
||||
This test verifies that all environments listed in `lerobot/__init__.py` can
|
||||
|
@ -17,9 +19,6 @@ def test_available_env_task(env_name: str, task_name: list):
|
|||
`available_tasks_per_env` are valid.
|
||||
"""
|
||||
package_name = f"gym_{env_name}"
|
||||
if not is_package_available(package_name):
|
||||
pytest.skip(f"gym-{env_name} not installed")
|
||||
|
||||
importlib.import_module(package_name)
|
||||
gym_handle = f"{package_name}/{task_name}"
|
||||
assert gym_handle in gym.envs.registry.keys(), gym_handle
|
||||
|
|
|
@ -12,21 +12,19 @@ from lerobot.common.utils.utils import init_hydra_config
|
|||
import lerobot
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("obs_type", OBS_TYPES)
|
||||
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
|
||||
@require_env
|
||||
def test_env(env_name, env_task, obs_type):
|
||||
if env_name == "aloha" and obs_type == "state":
|
||||
pytest.skip("`state` observations not available for aloha")
|
||||
|
||||
package_name = f"gym_{env_name}"
|
||||
if not is_package_available(package_name):
|
||||
pytest.skip(f"gym-{env_name} not installed")
|
||||
|
||||
importlib.import_module(package_name)
|
||||
env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type)
|
||||
check_env(env.unwrapped, skip_render_check=True)
|
||||
|
@ -34,6 +32,7 @@ def test_env(env_name, env_task, obs_type):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("env_name", lerobot.available_envs)
|
||||
@require_env
|
||||
def test_factory(env_name):
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
|
|
|
@ -8,7 +8,7 @@ from lerobot.common.policies.policy_protocol import Policy
|
|||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env
|
||||
|
||||
|
||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||
|
@ -24,6 +24,7 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
|||
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
|
||||
],
|
||||
)
|
||||
@require_env
|
||||
def test_policy(env_name, policy_name, extra_overrides):
|
||||
"""
|
||||
Tests:
|
||||
|
|
|
@ -1,6 +1,36 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.utils.import_utils import is_package_available
|
||||
|
||||
# Pass this as the first argument to init_hydra_config.
|
||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
def require_env(func):
|
||||
"""
|
||||
Decorator that skips the test if the required environment package is not installed.
|
||||
As it need 'env_name' in args, it also checks whether it is provided as an argument.
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Determine if 'env_name' is provided and extract its value
|
||||
arg_names = func.__code__.co_varnames[:func.__code__.co_argcount]
|
||||
if 'env_name' in arg_names:
|
||||
# Get the index of 'env_name' and retrieve the value from args
|
||||
index = arg_names.index('env_name')
|
||||
env_name = args[index] if len(args) > index else kwargs.get('env_name')
|
||||
else:
|
||||
raise ValueError("Function does not have 'env_name' as an argument.")
|
||||
|
||||
# Perform the package check
|
||||
package_name = f"gym_{env_name}"
|
||||
if not is_package_available(package_name):
|
||||
pytest.skip(f"gym-{env_name} not installed")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
Loading…
Reference in New Issue