diff --git a/tests/test_policies.py b/tests/test_policies.py index 75633fe6..bb0c7b80 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -31,7 +31,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config from tests.scripts.save_policy_to_safetensor import get_policy_stats -from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel +from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel @pytest.mark.parametrize("policy_name", available_policies) @@ -296,16 +296,17 @@ def test_normalize(insert_temporal_dim): # As artifacts have been generated on an x86_64 kernel, this test won't # pass if it's run on another platform due to floating point errors @require_x86_64_kernel +@require_cpu def test_backward_compatibility(env_name, policy_name, extra_overrides): """ NOTE: If this test does not pass, and you have intentionally changed something in the policy: 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should include a report on what changed and how that affected the outputs. - 2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and + 2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and add the policies you want to update the test artifacts for. - 3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated. + 3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated. 4. Check that this test now passes. - 5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state. + 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. 6. Remember to stage and commit the resulting changes to `tests/data`. """ env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"