WIP
This commit is contained in:
parent
276d210380
commit
3a918b980f
|
@ -1,6 +1,7 @@
|
||||||
# @package _global_
|
# @package _global_
|
||||||
|
|
||||||
seed: 1
|
seed: 1
|
||||||
|
dataset_repo_id: lerobot/xarm_lift_medium_replay
|
||||||
|
|
||||||
training:
|
training:
|
||||||
offline_steps: 25000
|
offline_steps: 25000
|
||||||
|
|
|
@ -82,9 +82,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env_policies = [
|
env_policies = [
|
||||||
# ("xarm", "tdmpc"),
|
("xarm", "tdmpc"),
|
||||||
("pusht", "diffusion"),
|
# ("pusht", "diffusion"),
|
||||||
("aloha", "act"),
|
# ("aloha", "act"),
|
||||||
]
|
]
|
||||||
for env, policy in env_policies:
|
for env, policy in env_policies:
|
||||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy)
|
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy)
|
||||||
|
|
|
@ -239,7 +239,7 @@ def test_normalize(insert_temporal_dim):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_name,policy_name",
|
"env_name,policy_name",
|
||||||
[
|
[
|
||||||
# ("xarm", "tdmpc"),
|
("xarm", "tdmpc"),
|
||||||
("pusht", "diffusion"),
|
("pusht", "diffusion"),
|
||||||
("aloha", "act"),
|
("aloha", "act"),
|
||||||
],
|
],
|
||||||
|
|
Loading…
Reference in New Issue