Rename dora_aloha_real, WIP test_policies
This commit is contained in:
parent
03d237fe0f
commit
92d1aecb40
|
@ -55,7 +55,7 @@ available_tasks_per_env = {
|
|||
],
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||
"dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
|
@ -81,7 +81,7 @@ available_datasets_per_env = {
|
|||
"lerobot/xarm_push_medium_image",
|
||||
"lerobot/xarm_push_medium_replay_image",
|
||||
],
|
||||
"dora_aloha_real": [
|
||||
"dora": [
|
||||
"lerobot/aloha_static_battery",
|
||||
"lerobot/aloha_static_candy",
|
||||
"lerobot/aloha_static_coffee",
|
||||
|
@ -139,7 +139,7 @@ available_policies = [
|
|||
# keys and values refer to yaml files
|
||||
available_policies_per_env = {
|
||||
"aloha": ["act"],
|
||||
"aloha_real": ["act"],
|
||||
"dora": ["act"],
|
||||
"pusht": ["diffusion"],
|
||||
"xarm": ["tdmpc"],
|
||||
"dora_aloha_real": ["act_real"],
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: dora
|
||||
task: DoraAloha-v0
|
||||
state_dim: 14
|
||||
action_dim: 14
|
||||
fps: ${fps}
|
||||
episode_length: 400
|
||||
gym:
|
||||
fps: ${fps}
|
|
@ -77,7 +77,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
|||
batch = next(iter(dataloader))
|
||||
obs = {}
|
||||
for k in batch:
|
||||
if k.startswith("observation"):
|
||||
if "observation" in k:
|
||||
obs[k] = batch[k]
|
||||
|
||||
if "n_action_steps" in cfg.policy:
|
||||
|
@ -115,8 +115,8 @@ if __name__ == "__main__":
|
|||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real", []),
|
||||
("dora_aloha_real", "act_real_no_state", []),
|
||||
]
|
||||
for env, policy, extra_overrides in env_policies:
|
||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
||||
|
|
Loading…
Reference in New Issue