From a74c1100f4348d03d8d197f1efeb6ca036773f55 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Thu, 30 May 2024 17:54:59 +0000 Subject: [PATCH] Rename dora_aloha_real, WIP test_policies --- lerobot/__init__.py | 20 ++++++++++++++++++- .../{aloha_real.yaml => dora_aloha_real.yaml} | 0 ...ensor.py => save_policy_to_safetensors.py} | 19 ++++++++++-------- tests/test_policies.py | 6 +++++- 4 files changed, 35 insertions(+), 10 deletions(-) rename lerobot/configs/env/{aloha_real.yaml => dora_aloha_real.yaml} (100%) rename tests/scripts/{save_policy_to_safetensor.py => save_policy_to_safetensors.py} (92%) diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 37db0c18..495a0235 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -52,6 +52,7 @@ available_tasks_per_env = { ], "pusht": ["PushT-v0"], "xarm": ["XarmLift-v0"], + "dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"], } available_envs = list(available_tasks_per_env.keys()) @@ -77,6 +78,23 @@ available_datasets_per_env = { "lerobot/xarm_push_medium_image", "lerobot/xarm_push_medium_replay_image", ], + "dora": [ + "lerobot/aloha_static_battery", + "lerobot/aloha_static_candy", + "lerobot/aloha_static_coffee", + "lerobot/aloha_static_coffee_new", + "lerobot/aloha_static_cups_open", + "lerobot/aloha_static_fork_pick_up", + "lerobot/aloha_static_pingpong_test", + "lerobot/aloha_static_pro_pencil", + "lerobot/aloha_static_screw_driver", + "lerobot/aloha_static_tape", + "lerobot/aloha_static_thread_velcro", + "lerobot/aloha_static_towel", + "lerobot/aloha_static_vinh_cup", + "lerobot/aloha_static_vinh_cup_left", + "lerobot/aloha_static_ziploc_slide", + ], } available_real_world_datasets = [ @@ -116,7 +134,7 @@ available_policies = [ available_policies_per_env = { "aloha": ["act"], - "aloha_real": ["act"], + "dora": ["act"], "pusht": ["diffusion"], "xarm": ["tdmpc"], } diff --git a/lerobot/configs/env/aloha_real.yaml b/lerobot/configs/env/dora_aloha_real.yaml similarity index 100% rename from lerobot/configs/env/aloha_real.yaml rename to lerobot/configs/env/dora_aloha_real.yaml diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensors.py similarity index 92% rename from tests/scripts/save_policy_to_safetensor.py rename to tests/scripts/save_policy_to_safetensors.py index 89f33374..f867a5e8 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides): # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension dataset.delta_timestamps = None batch = next(iter(dataloader)) - obs = { - k: batch[k] - for k in batch - if k in ["observation.image", "observation.images.top", "observation.state"] - } + obs = {} + for k in batch: + if "observation" in k: + obs[k] = batch[k] + + if "n_action_steps" in cfg.policy: + actions_queue = cfg.policy.n_action_steps + else: + actions_queue = cfg.policy.n_action_repeats - actions_queue = ( - cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats - ) actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} return output_dict, grad_stats, param_stats, actions @@ -114,6 +115,8 @@ if __name__ == "__main__": ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ), ("aloha", "act", ["policy.n_action_steps=10"]), + ("dora_aloha_real", "act_real", []), + ("dora_aloha_real", "act_real_no_state", []), ] for env, policy, extra_overrides in env_policies: save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) diff --git a/tests/test_policies.py b/tests/test_policies.py index bb0c7b80..6378e254 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config -from tests.scripts.save_policy_to_safetensor import get_policy_stats +from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str): ), # Note: these parameters also need custom logic in the test function for overriding the Hydra config. ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]), + ("dora_aloha_real", "act_real", []), + ("dora_aloha_real", "act_real_no_state", []), ], ) @require_env @@ -291,6 +293,8 @@ def test_normalize(insert_temporal_dim): ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ), ("aloha", "act", ["policy.n_action_steps=10"]), + ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]), + ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]), ], ) # As artifacts have been generated on an x86_64 kernel, this test won't