This commit is contained in:
Simon Alibert 2024-05-01 20:16:04 +02:00
parent 276d210380
commit 3a918b980f
3 changed files with 5 additions and 4 deletions

View File

@ -1,6 +1,7 @@
# @package _global_
seed: 1
dataset_repo_id: lerobot/xarm_lift_medium_replay
training:
offline_steps: 25000

View File

@ -82,9 +82,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name):
if __name__ == "__main__":
env_policies = [
# ("xarm", "tdmpc"),
("pusht", "diffusion"),
("aloha", "act"),
("xarm", "tdmpc"),
# ("pusht", "diffusion"),
# ("aloha", "act"),
]
for env, policy in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy)

View File

@ -239,7 +239,7 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize(
"env_name,policy_name",
[
# ("xarm", "tdmpc"),
("xarm", "tdmpc"),
("pusht", "diffusion"),
("aloha", "act"),
],