This commit is contained in:
Alexander Soare 2024-05-05 11:17:02 +01:00
parent 3ecf6b4f3f
commit 2741b5c59f
3 changed files with 9 additions and 11 deletions

View File

@ -1,5 +1,9 @@
# @package _global_
# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
# https://github.com/huggingface/lerobot/pull/134 for more details.
seed: 100000
dataset_repo_id: lerobot/pusht

View File

@ -88,14 +88,8 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
if __name__ == "__main__":
env_policies = [
# ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
# (
# "pusht",
# "diffusion",
# ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
# ),
# ("aloha", "act", ["policy.n_action_steps=10"]),
]
# Instructions: include the policies that you want to save artifacts for here. Please make sure to revert
# your changes when you are done.
env_policies = []
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

View File

@ -242,7 +242,7 @@ def test_normalize(insert_temporal_dim):
"diffusion",
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
# ("aloha", "act", ["policy.n_action_steps=10"]),
("aloha", "act", ["policy.n_action_steps=10"]),
],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
@ -254,7 +254,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and
comment in the policies you want to update the test artifacts for.
add the policies you want to update the test artifacts for.
3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
4. Check that this test now passes.
5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state.