This commit is contained in:
Simon Alibert 2024-05-01 20:16:04 +02:00
parent 276d210380
commit 3a918b980f
3 changed files with 5 additions and 4 deletions

View File

@ -1,6 +1,7 @@
# @package _global_ # @package _global_
seed: 1 seed: 1
dataset_repo_id: lerobot/xarm_lift_medium_replay
training: training:
offline_steps: 25000 offline_steps: 25000

View File

@ -82,9 +82,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name):
if __name__ == "__main__": if __name__ == "__main__":
env_policies = [ env_policies = [
# ("xarm", "tdmpc"), ("xarm", "tdmpc"),
("pusht", "diffusion"), # ("pusht", "diffusion"),
("aloha", "act"), # ("aloha", "act"),
] ]
for env, policy in env_policies: for env, policy in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy) save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy)

View File

@ -239,7 +239,7 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name,policy_name", "env_name,policy_name",
[ [
# ("xarm", "tdmpc"), ("xarm", "tdmpc"),
("pusht", "diffusion"), ("pusht", "diffusion"),
("aloha", "act"), ("aloha", "act"),
], ],