Add require_x86_64_kernel
This commit is contained in:
parent
55ff23c252
commit
590b0eb48f
|
@ -16,7 +16,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
|
|||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from tests.scripts.save_policy_to_safetensor import get_policy_stats
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
|
@ -244,6 +244,7 @@ def test_normalize(insert_temporal_dim):
|
|||
("aloha", "act", []),
|
||||
],
|
||||
)
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
||||
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
@ -9,6 +11,21 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
|||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def require_x86_64_kernel(func):
|
||||
"""
|
||||
Decorator that skips the test if plateform device is not an x86_64 cpu.
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if platform.machine() != "x86_64":
|
||||
pytest.skip("requires x86_64 plateform")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_cpu(func):
|
||||
"""
|
||||
Decorator that skips the test if device is not cpu.
|
||||
|
|
Loading…
Reference in New Issue