diff --git a/tests/test_policies.py b/tests/test_policies.py index 8cd4a804..e25f3f9f 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -16,7 +16,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config from tests.scripts.save_policy_to_safetensor import get_policy_stats -from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env +from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel @pytest.mark.parametrize("policy_name", available_policies) @@ -244,6 +244,7 @@ def test_normalize(insert_temporal_dim): ("aloha", "act", []), ], ) +@require_x86_64_kernel def test_backward_compatibility(env_name, policy_name, extra_overrides): env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}" saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors") diff --git a/tests/utils.py b/tests/utils.py index 3edf055d..6a706694 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,5 @@ +import platform + import pytest import torch @@ -9,6 +11,21 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +def require_x86_64_kernel(func): + """ + Decorator that skips the test if plateform device is not an x86_64 cpu. + """ + from functools import wraps + + @wraps(func) + def wrapper(*args, **kwargs): + if platform.machine() != "x86_64": + pytest.skip("requires x86_64 plateform") + return func(*args, **kwargs) + + return wrapper + + def require_cpu(func): """ Decorator that skips the test if device is not cpu.