Rename dora_aloha_real, WIP test_policies
This commit is contained in:
parent
b7b5c3b4ff
commit
671ad93b6c
|
@ -52,6 +52,7 @@ available_tasks_per_env = {
|
||||||
],
|
],
|
||||||
"pusht": ["PushT-v0"],
|
"pusht": ["PushT-v0"],
|
||||||
"xarm": ["XarmLift-v0"],
|
"xarm": ["XarmLift-v0"],
|
||||||
|
"dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||||
}
|
}
|
||||||
available_envs = list(available_tasks_per_env.keys())
|
available_envs = list(available_tasks_per_env.keys())
|
||||||
|
|
||||||
|
@ -77,6 +78,23 @@ available_datasets_per_env = {
|
||||||
"lerobot/xarm_push_medium_image",
|
"lerobot/xarm_push_medium_image",
|
||||||
"lerobot/xarm_push_medium_replay_image",
|
"lerobot/xarm_push_medium_replay_image",
|
||||||
],
|
],
|
||||||
|
"dora": [
|
||||||
|
"lerobot/aloha_static_battery",
|
||||||
|
"lerobot/aloha_static_candy",
|
||||||
|
"lerobot/aloha_static_coffee",
|
||||||
|
"lerobot/aloha_static_coffee_new",
|
||||||
|
"lerobot/aloha_static_cups_open",
|
||||||
|
"lerobot/aloha_static_fork_pick_up",
|
||||||
|
"lerobot/aloha_static_pingpong_test",
|
||||||
|
"lerobot/aloha_static_pro_pencil",
|
||||||
|
"lerobot/aloha_static_screw_driver",
|
||||||
|
"lerobot/aloha_static_tape",
|
||||||
|
"lerobot/aloha_static_thread_velcro",
|
||||||
|
"lerobot/aloha_static_towel",
|
||||||
|
"lerobot/aloha_static_vinh_cup",
|
||||||
|
"lerobot/aloha_static_vinh_cup_left",
|
||||||
|
"lerobot/aloha_static_ziploc_slide",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
available_real_world_datasets = [
|
available_real_world_datasets = [
|
||||||
|
@ -116,7 +134,7 @@ available_policies = [
|
||||||
|
|
||||||
available_policies_per_env = {
|
available_policies_per_env = {
|
||||||
"aloha": ["act"],
|
"aloha": ["act"],
|
||||||
"aloha_real": ["act"],
|
"dora": ["act"],
|
||||||
"pusht": ["diffusion"],
|
"pusht": ["diffusion"],
|
||||||
"xarm": ["tdmpc"],
|
"xarm": ["tdmpc"],
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||||
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
||||||
dataset.delta_timestamps = None
|
dataset.delta_timestamps = None
|
||||||
batch = next(iter(dataloader))
|
batch = next(iter(dataloader))
|
||||||
obs = {
|
obs = {}
|
||||||
k: batch[k]
|
for k in batch:
|
||||||
for k in batch
|
if "observation" in k:
|
||||||
if k in ["observation.image", "observation.images.top", "observation.state"]
|
obs[k] = batch[k]
|
||||||
}
|
|
||||||
|
if "n_action_steps" in cfg.policy:
|
||||||
|
actions_queue = cfg.policy.n_action_steps
|
||||||
|
else:
|
||||||
|
actions_queue = cfg.policy.n_action_repeats
|
||||||
|
|
||||||
actions_queue = (
|
|
||||||
cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
|
|
||||||
)
|
|
||||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||||
return output_dict, grad_stats, param_stats, actions
|
return output_dict, grad_stats, param_stats, actions
|
||||||
|
|
||||||
|
@ -114,6 +115,8 @@ if __name__ == "__main__":
|
||||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||||
),
|
),
|
||||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||||
|
("dora_aloha_real", "act_real", []),
|
||||||
|
("dora_aloha_real", "act_real_no_state", []),
|
||||||
]
|
]
|
||||||
for env, policy, extra_overrides in env_policies:
|
for env, policy, extra_overrides in env_policies:
|
||||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
|
@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
from tests.scripts.save_policy_to_safetensor import get_policy_stats
|
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
|
@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||||
),
|
),
|
||||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||||
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
||||||
|
("dora_aloha_real", "act_real", []),
|
||||||
|
("dora_aloha_real", "act_real_no_state", []),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@require_env
|
@require_env
|
||||||
|
@ -291,6 +293,8 @@ def test_normalize(insert_temporal_dim):
|
||||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||||
),
|
),
|
||||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||||
|
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||||
|
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||||
|
|
Loading…
Reference in New Issue