Rename dora_aloha_real, WIP test_policies

This commit is contained in:
Remi Cadene 2024-05-30 17:54:59 +00:00
parent b7b5c3b4ff
commit 671ad93b6c
4 changed files with 35 additions and 10 deletions

View File

@ -52,6 +52,7 @@ available_tasks_per_env = {
], ],
"pusht": ["PushT-v0"], "pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"], "xarm": ["XarmLift-v0"],
"dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
} }
available_envs = list(available_tasks_per_env.keys()) available_envs = list(available_tasks_per_env.keys())
@ -77,6 +78,23 @@ available_datasets_per_env = {
"lerobot/xarm_push_medium_image", "lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image", "lerobot/xarm_push_medium_replay_image",
], ],
"dora": [
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
"lerobot/aloha_static_coffee_new",
"lerobot/aloha_static_cups_open",
"lerobot/aloha_static_fork_pick_up",
"lerobot/aloha_static_pingpong_test",
"lerobot/aloha_static_pro_pencil",
"lerobot/aloha_static_screw_driver",
"lerobot/aloha_static_tape",
"lerobot/aloha_static_thread_velcro",
"lerobot/aloha_static_towel",
"lerobot/aloha_static_vinh_cup",
"lerobot/aloha_static_vinh_cup_left",
"lerobot/aloha_static_ziploc_slide",
],
} }
available_real_world_datasets = [ available_real_world_datasets = [
@ -116,7 +134,7 @@ available_policies = [
available_policies_per_env = { available_policies_per_env = {
"aloha": ["act"], "aloha": ["act"],
"aloha_real": ["act"], "dora": ["act"],
"pusht": ["diffusion"], "pusht": ["diffusion"],
"xarm": ["tdmpc"], "xarm": ["tdmpc"],
} }

View File

@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
dataset.delta_timestamps = None dataset.delta_timestamps = None
batch = next(iter(dataloader)) batch = next(iter(dataloader))
obs = { obs = {}
k: batch[k] for k in batch:
for k in batch if "observation" in k:
if k in ["observation.image", "observation.images.top", "observation.state"] obs[k] = batch[k]
}
if "n_action_steps" in cfg.policy:
actions_queue = cfg.policy.n_action_steps
else:
actions_queue = cfg.policy.n_action_repeats
actions_queue = (
cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
)
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
return output_dict, grad_stats, param_stats, actions return output_dict, grad_stats, param_stats, actions
@ -114,6 +115,8 @@ if __name__ == "__main__":
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
), ),
("aloha", "act", ["policy.n_action_steps=10"]), ("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", []),
("dora_aloha_real", "act_real_no_state", []),
] ]
for env, policy, extra_overrides in env_policies: for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

View File

@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
from tests.scripts.save_policy_to_safetensor import get_policy_stats from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str):
), ),
# Note: these parameters also need custom logic in the test function for overriding the Hydra config. # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]), ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
("dora_aloha_real", "act_real", []),
("dora_aloha_real", "act_real_no_state", []),
], ],
) )
@require_env @require_env
@ -291,6 +293,8 @@ def test_normalize(insert_temporal_dim):
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
), ),
("aloha", "act", ["policy.n_action_steps=10"]), ("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
], ],
) )
# As artifacts have been generated on an x86_64 kernel, this test won't # As artifacts have been generated on an x86_64 kernel, this test won't