Add require_x86_64_kernel
This commit is contained in:
parent
55ff23c252
commit
590b0eb48f
|
@ -16,7 +16,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
from tests.scripts.save_policy_to_safetensor import get_policy_stats
|
from tests.scripts.save_policy_to_safetensor import get_policy_stats
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_name", available_policies)
|
@pytest.mark.parametrize("policy_name", available_policies)
|
||||||
|
@ -244,6 +244,7 @@ def test_normalize(insert_temporal_dim):
|
||||||
("aloha", "act", []),
|
("aloha", "act", []),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@require_x86_64_kernel
|
||||||
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||||
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
||||||
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import platform
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -9,6 +11,21 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def require_x86_64_kernel(func):
|
||||||
|
"""
|
||||||
|
Decorator that skips the test if plateform device is not an x86_64 cpu.
|
||||||
|
"""
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if platform.machine() != "x86_64":
|
||||||
|
pytest.skip("requires x86_64 plateform")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def require_cpu(func):
|
def require_cpu(func):
|
||||||
"""
|
"""
|
||||||
Decorator that skips the test if device is not cpu.
|
Decorator that skips the test if device is not cpu.
|
||||||
|
|
Loading…
Reference in New Issue