Add require_x86_64_kernel

This commit is contained in:
Simon Alibert 2024-05-02 15:45:58 +02:00
parent 55ff23c252
commit 590b0eb48f
2 changed files with 19 additions and 1 deletions

View File

@ -16,7 +16,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from tests.scripts.save_policy_to_safetensor import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel
@pytest.mark.parametrize("policy_name", available_policies)
@ -244,6 +244,7 @@ def test_normalize(insert_temporal_dim):
("aloha", "act", []),
],
)
@require_x86_64_kernel
def test_backward_compatibility(env_name, policy_name, extra_overrides):
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")

View File

@ -1,3 +1,5 @@
import platform
import pytest
import torch
@ -9,6 +11,21 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def require_x86_64_kernel(func):
"""
Decorator that skips the test if plateform device is not an x86_64 cpu.
"""
from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
if platform.machine() != "x86_64":
pytest.skip("requires x86_64 plateform")
return func(*args, **kwargs)
return wrapper
def require_cpu(func):
"""
Decorator that skips the test if device is not cpu.