diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors index 105881ed..48288106 100644 Binary files a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors differ diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index 06763d72..a46b2c05 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -89,13 +89,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": env_policies = [ - ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), - ( - "pusht", - "diffusion", - ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], - ), - ("aloha", "act", ["policy.n_action_steps=10"]), + ("xarm", "tdmpc", []), + ("pusht", "diffusion", ["policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"]), + ("aloha", "act", []), ] for env, policy, extra_overrides in env_policies: save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) diff --git a/tests/test_policies.py b/tests/test_policies.py index efc33281..c61a8f26 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -239,13 +239,9 @@ def test_normalize(insert_temporal_dim): @pytest.mark.parametrize( "env_name, policy_name, extra_overrides", [ - ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), - ( - "pusht", - "diffusion", - ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], - ), - ("aloha", "act", ["policy.n_action_steps=10"]), + ("xarm", "tdmpc", []), + ("pusht", "diffusion", ["policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"]), + ("aloha", "act", []), ], ) # As artifacts have been generated on an x86_64 kernel, this test won't