117 lines
3.3 KiB
Python
117 lines
3.3 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import platform
|
|
from functools import wraps
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from lerobot.common.utils.import_utils import is_package_available
|
|
|
|
# Path of the default config file; hand it to init_hydra_config as its first
# positional argument.
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"

# Torch device string used by the tests: "cuda" whenever
# torch.cuda.is_available() reports True, otherwise "cpu".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
|
|
def require_x86_64_kernel(func):
    """
    Decorator that skips the test if the platform device is not an x86_64 cpu.

    ``platform.machine()`` reports an x86_64 machine as ``"x86_64"`` on Linux and
    macOS but as ``"AMD64"`` on Windows (and ``"amd64"`` on some BSDs), so the
    check is case-insensitive and accepts both spellings.

    Args:
        func: The test function to guard.

    Returns:
        A wrapper with ``func``'s metadata that calls ``pytest.skip`` on
        non-x86_64 machines and otherwise forwards all arguments to ``func``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Normalise case so "AMD64" / "amd64" / "x86_64" all count as x86_64.
        if platform.machine().lower() not in ("x86_64", "amd64"):
            pytest.skip("requires x86_64 platform")
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def require_cpu(func):
    """
    Decorator that skips the test if device is not cpu.

    The decision compares the module-level ``DEVICE`` constant with ``"cpu"``
    every time the decorated test is called.

    Args:
        func: The test function to guard.

    Returns:
        A wrapper with ``func``'s metadata that calls ``pytest.skip`` when
        ``DEVICE`` is not ``"cpu"`` and otherwise forwards all arguments to ``func``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if DEVICE != "cpu":
            pytest.skip("requires cpu")
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def require_cuda(func):
    """
    Decorator that skips the test if cuda is not available.

    ``torch.cuda.is_available()`` is evaluated on every call of the decorated
    test, not once at decoration time.

    Args:
        func: The test function to guard.

    Returns:
        A wrapper with ``func``'s metadata that calls ``pytest.skip`` when CUDA is
        unavailable and otherwise forwards all arguments to ``func``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not torch.cuda.is_available():
            pytest.skip("requires cuda")
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def require_env(func):
    """
    Decorator that skips the test if the required environment package is not installed.

    The decorated function must declare an ``env_name`` parameter. At call time its
    value ``X`` selects the package ``gym_X``, whose availability is checked via
    ``is_package_available``. ``env_name`` may be passed positionally, by keyword,
    or left to its default value, and may be a keyword-only parameter.

    Args:
        func: The test function to guard; must accept ``env_name``.

    Returns:
        A wrapper with ``func``'s metadata that skips when ``gym_<env_name>`` is
        unavailable and otherwise forwards all arguments to ``func``.

    Raises:
        ValueError: When the wrapper is called and ``func`` has no ``env_name``
            parameter.
    """
    import inspect

    # Resolve the signature once. ``inspect.signature`` follows ``__wrapped__``,
    # lists keyword-only parameters and exposes defaults, none of which a plain
    # ``co_varnames[:co_argcount]`` lookup handles.
    signature = inspect.signature(func)
    has_env_name = "env_name" in signature.parameters

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not has_env_name:
            raise ValueError("Function does not have 'env_name' as an argument.")

        try:
            bound = signature.bind_partial(*args, **kwargs)
        except TypeError:
            # Arguments don't fit the signature; take the keyword value (if any)
            # and let ``func`` raise its own TypeError when called below.
            env_name = kwargs.get("env_name")
        else:
            # Fill in defaults so ``def f(env_name="aloha")`` called without the
            # argument checks ``gym_aloha`` rather than ``gym_None``.
            bound.apply_defaults()
            env_name = bound.arguments.get("env_name")

        # Perform the package check
        package_name = f"gym_{env_name}"
        if not is_package_available(package_name):
            pytest.skip(f"gym-{env_name} not installed")

        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def require_package(package_name):
    """
    Build a decorator that skips a test whenever ``package_name`` is not installed.

    Args:
        package_name: Package name passed to ``is_package_available`` each time
            the decorated test is called.

    Returns:
        A decorator. It wraps a test function, preserving its metadata, and calls
        ``pytest.skip`` when the package is unavailable; otherwise it forwards
        all arguments to the test.
    """

    def decorator(test_fn):
        @wraps(test_fn)
        def wrapper(*args, **kwargs):
            installed = is_package_available(package_name)
            if not installed:
                pytest.skip(f"{package_name} not installed")
            return test_fn(*args, **kwargs)

        return wrapper

    return decorator
|