lerobot/tests/utils.py

85 lines
2.2 KiB
Python

import platform
import pytest
import torch
from lerobot.common.utils.import_utils import is_package_available
# Pass this as the first argument to init_hydra_config.
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"

# Device the tests run on: "cuda" when torch can see a CUDA device, otherwise "cpu".
# Evaluated once, when this module is imported.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def require_x86_64_kernel(func):
    """
    Decorator that skips the test if the platform's machine type is not x86_64.

    The check runs every time the decorated function is called, by comparing
    ``platform.machine()`` against "x86_64". On a mismatch ``pytest.skip`` is
    raised instead of running ``func``. The wrapper's name and docstring are
    copied from ``func`` via ``functools.wraps``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # NOTE(review): Windows reports x86_64 hosts as "AMD64", so they are
        # skipped by this exact comparison — confirm whether that is intended.
        if platform.machine() != "x86_64":
            pytest.skip("requires x86_64 platform")
        return func(*args, **kwargs)

    return wrapper
def require_cpu(func):
    """
    Skip the decorated test unless the module-level ``DEVICE`` equals "cpu".

    ``DEVICE`` is looked up when the wrapped function is called, not when it
    is decorated. The wrapper carries ``func``'s name and docstring
    (``functools.wraps``).
    """
    from functools import wraps

    @wraps(func)
    def _run_on_cpu_only(*args, **kwargs):
        running_on_cpu = DEVICE == "cpu"
        if not running_on_cpu:
            pytest.skip("requires cpu")
        return func(*args, **kwargs)

    return _run_on_cpu_only
def require_cuda(func):
    """
    Skip the decorated test when torch reports that CUDA is unavailable.

    ``torch.cuda.is_available()`` is queried on every call of the wrapped
    function. The wrapper carries ``func``'s name and docstring
    (``functools.wraps``).
    """
    from functools import wraps

    @wraps(func)
    def _run_with_cuda_only(*args, **kwargs):
        cuda_ok = torch.cuda.is_available()
        if not cuda_ok:
            pytest.skip("requires cuda")
        return func(*args, **kwargs)

    return _run_with_cuda_only
def require_env(func):
    """
    Decorator that skips the test if the required environment package is not installed.

    The decorated function must declare an ``env_name`` parameter (positional,
    keyword, or keyword-only). At call time the value of ``env_name`` is resolved
    from the call's arguments, falling back to the parameter's default when the
    caller omits it. The test is skipped when the package ``gym_{env_name}`` is
    not available.

    Raises:
        ValueError: when the wrapped function is called, if ``func`` has no
            ``env_name`` parameter.
    """
    import inspect
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # inspect.signature includes keyword-only parameters and follows
        # __wrapped__, unlike func.__code__.co_varnames[:co_argcount].
        signature = inspect.signature(func)
        if "env_name" not in signature.parameters:
            raise ValueError("Function does not have 'env_name' as an argument.")
        try:
            bound = signature.bind_partial(*args, **kwargs)
        except TypeError:
            # The arguments don't fit the signature; let func raise its own error.
            return func(*args, **kwargs)
        # Fill in env_name's default when the caller did not pass it.
        bound.apply_defaults()
        # None when env_name was neither passed nor given a default.
        env_name = bound.arguments.get("env_name")

        # Perform the package check
        package_name = f"gym_{env_name}"
        if not is_package_available(package_name):
            pytest.skip(f"gym-{env_name} not installed")
        return func(*args, **kwargs)

    return wrapper