Add require_env wrapper
This commit is contained in:
parent
d167e5e3c5
commit
37efcea3eb
|
@ -7,9 +7,11 @@ from lerobot.common.utils.import_utils import is_package_available
|
||||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||||
|
from tests.utils import require_env
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
|
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
|
||||||
|
@require_env
|
||||||
def test_available_env_task(env_name: str, task_name: list):
|
def test_available_env_task(env_name: str, task_name: list):
|
||||||
"""
|
"""
|
||||||
This test verifies that all environments listed in `lerobot/__init__.py` can
|
This test verifies that all environments listed in `lerobot/__init__.py` can
|
||||||
|
@ -17,9 +19,6 @@ def test_available_env_task(env_name: str, task_name: list):
|
||||||
`available_tasks_per_env` are valid.
|
`available_tasks_per_env` are valid.
|
||||||
"""
|
"""
|
||||||
package_name = f"gym_{env_name}"
|
package_name = f"gym_{env_name}"
|
||||||
if not is_package_available(package_name):
|
|
||||||
pytest.skip(f"gym-{env_name} not installed")
|
|
||||||
|
|
||||||
importlib.import_module(package_name)
|
importlib.import_module(package_name)
|
||||||
gym_handle = f"{package_name}/{task_name}"
|
gym_handle = f"{package_name}/{task_name}"
|
||||||
assert gym_handle in gym.envs.registry.keys(), gym_handle
|
assert gym_handle in gym.envs.registry.keys(), gym_handle
|
||||||
|
|
|
@ -12,21 +12,19 @@ from lerobot.common.utils.utils import init_hydra_config
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.common.envs.utils import preprocess_observation
|
from lerobot.common.envs.utils import preprocess_observation
|
||||||
|
|
||||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env
|
||||||
|
|
||||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("obs_type", OBS_TYPES)
|
@pytest.mark.parametrize("obs_type", OBS_TYPES)
|
||||||
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
|
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
|
||||||
|
@require_env
|
||||||
def test_env(env_name, env_task, obs_type):
|
def test_env(env_name, env_task, obs_type):
|
||||||
if env_name == "aloha" and obs_type == "state":
|
if env_name == "aloha" and obs_type == "state":
|
||||||
pytest.skip("`state` observations not available for aloha")
|
pytest.skip("`state` observations not available for aloha")
|
||||||
|
|
||||||
package_name = f"gym_{env_name}"
|
package_name = f"gym_{env_name}"
|
||||||
if not is_package_available(package_name):
|
|
||||||
pytest.skip(f"gym-{env_name} not installed")
|
|
||||||
|
|
||||||
importlib.import_module(package_name)
|
importlib.import_module(package_name)
|
||||||
env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type)
|
env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type)
|
||||||
check_env(env.unwrapped, skip_render_check=True)
|
check_env(env.unwrapped, skip_render_check=True)
|
||||||
|
@ -34,6 +32,7 @@ def test_env(env_name, env_task, obs_type):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_name", lerobot.available_envs)
|
@pytest.mark.parametrize("env_name", lerobot.available_envs)
|
||||||
|
@require_env
|
||||||
def test_factory(env_name):
|
def test_factory(env_name):
|
||||||
cfg = init_hydra_config(
|
cfg = init_hydra_config(
|
||||||
DEFAULT_CONFIG_PATH,
|
DEFAULT_CONFIG_PATH,
|
||||||
|
|
|
@ -8,7 +8,7 @@ from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
from .utils import DEVICE, DEFAULT_CONFIG_PATH, require_env
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||||
|
@ -24,6 +24,7 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||||
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
|
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@require_env
|
||||||
def test_policy(env_name, policy_name, extra_overrides):
|
def test_policy(env_name, policy_name, extra_overrides):
|
||||||
"""
|
"""
|
||||||
Tests:
|
Tests:
|
||||||
|
|
|
@ -1,6 +1,36 @@
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.common.utils.import_utils import is_package_available
|
||||||
|
|
||||||
# Pass this as the first argument to init_hydra_config.
|
# Pass this as the first argument to init_hydra_config.
|
||||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||||
|
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
def require_env(func):
|
||||||
|
"""
|
||||||
|
Decorator that skips the test if the required environment package is not installed.
|
||||||
|
As it need 'env_name' in args, it also checks whether it is provided as an argument.
|
||||||
|
"""
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# Determine if 'env_name' is provided and extract its value
|
||||||
|
arg_names = func.__code__.co_varnames[:func.__code__.co_argcount]
|
||||||
|
if 'env_name' in arg_names:
|
||||||
|
# Get the index of 'env_name' and retrieve the value from args
|
||||||
|
index = arg_names.index('env_name')
|
||||||
|
env_name = args[index] if len(args) > index else kwargs.get('env_name')
|
||||||
|
else:
|
||||||
|
raise ValueError("Function does not have 'env_name' as an argument.")
|
||||||
|
|
||||||
|
# Perform the package check
|
||||||
|
package_name = f"gym_{env_name}"
|
||||||
|
if not is_package_available(package_name):
|
||||||
|
pytest.skip(f"gym-{env_name} not installed")
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
Loading…
Reference in New Issue