Revert "Override action queue defaults"

This reverts commit 7592c21eb2751c4ecc0fc1badb437f34b0f2155e.
This commit is contained in:
Simon Alibert 2024-05-03 09:57:36 +02:00
parent 61f38da7a3
commit 5647d71c08
3 changed files with 6 additions and 14 deletions

View File

@ -89,13 +89,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
if __name__ == "__main__": if __name__ == "__main__":
env_policies = [ env_policies = [
("xarm", "tdmpc", ["policy.n_action_repeats=2"]), ("xarm", "tdmpc", []),
( ("pusht", "diffusion", ["policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"]),
"pusht", ("aloha", "act", []),
"diffusion",
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
] ]
for env, policy, extra_overrides in env_policies: for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

View File

@ -239,13 +239,9 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name, policy_name, extra_overrides", "env_name, policy_name, extra_overrides",
[ [
("xarm", "tdmpc", ["policy.n_action_repeats=2"]), ("xarm", "tdmpc", []),
( ("pusht", "diffusion", ["policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"]),
"pusht", ("aloha", "act", []),
"diffusion",
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
], ],
) )
# As artifacts have been generated on an x86_64 kernel, this test won't # As artifacts have been generated on an x86_64 kernel, this test won't