require cpu for test_policies

This commit is contained in:
Remi Cadene 2024-05-19 19:08:32 +00:00
parent 780bf5d130
commit ed50e519bc
1 changed files with 5 additions and 4 deletions

View File

@ -31,7 +31,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
from tests.scripts.save_policy_to_safetensor import get_policy_stats from tests.scripts.save_policy_to_safetensor import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@pytest.mark.parametrize("policy_name", available_policies) @pytest.mark.parametrize("policy_name", available_policies)
@ -296,16 +296,17 @@ def test_normalize(insert_temporal_dim):
# As artifacts have been generated on an x86_64 kernel, this test won't # As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors # pass if it's run on another platform due to floating point errors
@require_x86_64_kernel @require_x86_64_kernel
@require_cpu
def test_backward_compatibility(env_name, policy_name, extra_overrides): def test_backward_compatibility(env_name, policy_name, extra_overrides):
""" """
NOTE: If this test does not pass, and you have intentionally changed something in the policy: NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
include a report on what changed and how that affected the outputs. include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and 2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for. add the policies you want to update the test artifacts for.
3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated. 3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
4. Check that this test now passes. 4. Check that this test now passes.
5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state. 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`. 6. Remember to stage and commit the resulting changes to `tests/data`.
""" """
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}" env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"