Rename dora_aloha_real, WIP test_policies

This commit is contained in:
Remi Cadene 2024-05-30 17:54:59 +00:00 committed by Thomas Wolf
parent 03d237fe0f
commit 92d1aecb40
3 changed files with 6 additions and 19 deletions

View File

@ -55,7 +55,7 @@ available_tasks_per_env = {
],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
"dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
}
available_envs = list(available_tasks_per_env.keys())
@ -81,7 +81,7 @@ available_datasets_per_env = {
"lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image",
],
"dora_aloha_real": [
"dora": [
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
@ -139,7 +139,7 @@ available_policies = [
# keys and values refer to yaml files
available_policies_per_env = {
"aloha": ["act"],
"aloha_real": ["act"],
"dora": ["act"],
"pusht": ["diffusion"],
"xarm": ["tdmpc"],
"dora_aloha_real": ["act_real"],

View File

@ -1,13 +0,0 @@
# @package _global_
fps: 30
env:
name: dora
task: DoraAloha-v0
state_dim: 14
action_dim: 14
fps: ${fps}
episode_length: 400
gym:
fps: ${fps}

View File

@ -77,7 +77,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
batch = next(iter(dataloader))
obs = {}
for k in batch:
if k.startswith("observation"):
if "observation" in k:
obs[k] = batch[k]
if "n_action_steps" in cfg.policy:
@ -115,8 +115,8 @@ if __name__ == "__main__":
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", []),
("dora_aloha_real", "act_real_no_state", []),
]
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)