require cpu for test_policies
This commit is contained in:
parent
780bf5d130
commit
ed50e519bc
|
@ -31,7 +31,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
from tests.scripts.save_policy_to_safetensor import get_policy_stats
|
from tests.scripts.save_policy_to_safetensor import get_policy_stats
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env, require_x86_64_kernel
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_name", available_policies)
|
@pytest.mark.parametrize("policy_name", available_policies)
|
||||||
|
@ -296,16 +296,17 @@ def test_normalize(insert_temporal_dim):
|
||||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||||
# pass if it's run on another platform due to floating point errors
|
# pass if it's run on another platform due to floating point errors
|
||||||
@require_x86_64_kernel
|
@require_x86_64_kernel
|
||||||
|
@require_cpu
|
||||||
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||||
"""
|
"""
|
||||||
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
||||||
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
||||||
include a report on what changed and how that affected the outputs.
|
include a report on what changed and how that affected the outputs.
|
||||||
2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and
|
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
|
||||||
add the policies you want to update the test artifacts for.
|
add the policies you want to update the test artifacts for.
|
||||||
3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
|
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
|
||||||
4. Check that this test now passes.
|
4. Check that this test now passes.
|
||||||
5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state.
|
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||||
6. Remember to stage and commit the resulting changes to `tests/data`.
|
6. Remember to stage and commit the resulting changes to `tests/data`.
|
||||||
"""
|
"""
|
||||||
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
||||||
|
|
Loading…
Reference in New Issue