This commit is contained in:
Alexander Soare 2024-05-05 11:17:02 +01:00
parent 3ecf6b4f3f
commit 2741b5c59f
3 changed files with 9 additions and 11 deletions

View File

@ -1,5 +1,9 @@
# @package _global_ # @package _global_
# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
# https://github.com/huggingface/lerobot/pull/134 for more details.
seed: 100000 seed: 100000
dataset_repo_id: lerobot/pusht dataset_repo_id: lerobot/pusht

View File

@ -88,14 +88,8 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
if __name__ == "__main__": if __name__ == "__main__":
env_policies = [ # Instructions: include the policies that you want to save artifacts for here. Please make sure to revert
# ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), # your changes when you are done.
# ( env_policies = []
# "pusht",
# "diffusion",
# ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
# ),
# ("aloha", "act", ["policy.n_action_steps=10"]),
]
for env, policy, extra_overrides in env_policies: for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

View File

@ -242,7 +242,7 @@ def test_normalize(insert_temporal_dim):
"diffusion", "diffusion",
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
), ),
# ("aloha", "act", ["policy.n_action_steps=10"]), ("aloha", "act", ["policy.n_action_steps=10"]),
], ],
) )
# As artifacts have been generated on an x86_64 kernel, this test won't # As artifacts have been generated on an x86_64 kernel, this test won't
@ -254,7 +254,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
include a report on what changed and how that affected the outputs. include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and 2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and
comment in the policies you want to update the test artifacts for. add the policies you want to update the test artifacts for.
3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated. 3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
4. Check that this test now passes. 4. Check that this test now passes.
5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state. 5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state.