WIP
This commit is contained in:
parent
276d210380
commit
3a918b980f
|
@ -1,6 +1,7 @@
|
|||
# @package _global_
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/xarm_lift_medium_replay
|
||||
|
||||
training:
|
||||
offline_steps: 25000
|
||||
|
|
|
@ -82,9 +82,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name):
|
|||
|
||||
if __name__ == "__main__":
|
||||
env_policies = [
|
||||
# ("xarm", "tdmpc"),
|
||||
("pusht", "diffusion"),
|
||||
("aloha", "act"),
|
||||
("xarm", "tdmpc"),
|
||||
# ("pusht", "diffusion"),
|
||||
# ("aloha", "act"),
|
||||
]
|
||||
for env, policy in env_policies:
|
||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy)
|
||||
|
|
|
@ -239,7 +239,7 @@ def test_normalize(insert_temporal_dim):
|
|||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name",
|
||||
[
|
||||
# ("xarm", "tdmpc"),
|
||||
("xarm", "tdmpc"),
|
||||
("pusht", "diffusion"),
|
||||
("aloha", "act"),
|
||||
],
|
||||
|
|
Loading…
Reference in New Issue