fix unit tests
This commit is contained in:
parent
d503f6294f
commit
b6d80052c4
|
@ -45,6 +45,9 @@ import itertools
|
||||||
|
|
||||||
from lerobot.__version__ import __version__ # noqa: F401
|
from lerobot.__version__ import __version__ # noqa: F401
|
||||||
|
|
||||||
|
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
|
||||||
|
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
|
||||||
|
# a yaml file AND a environment name. The difference should be more obvious.
|
||||||
available_tasks_per_env = {
|
available_tasks_per_env = {
|
||||||
"aloha": [
|
"aloha": [
|
||||||
"AlohaInsertion-v0",
|
"AlohaInsertion-v0",
|
||||||
|
@ -52,7 +55,7 @@ available_tasks_per_env = {
|
||||||
],
|
],
|
||||||
"pusht": ["PushT-v0"],
|
"pusht": ["PushT-v0"],
|
||||||
"xarm": ["XarmLift-v0"],
|
"xarm": ["XarmLift-v0"],
|
||||||
"dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||||
}
|
}
|
||||||
available_envs = list(available_tasks_per_env.keys())
|
available_envs = list(available_tasks_per_env.keys())
|
||||||
|
|
||||||
|
@ -78,7 +81,7 @@ available_datasets_per_env = {
|
||||||
"lerobot/xarm_push_medium_image",
|
"lerobot/xarm_push_medium_image",
|
||||||
"lerobot/xarm_push_medium_replay_image",
|
"lerobot/xarm_push_medium_replay_image",
|
||||||
],
|
],
|
||||||
"dora": [
|
"dora_aloha_real": [
|
||||||
"lerobot/aloha_static_battery",
|
"lerobot/aloha_static_battery",
|
||||||
"lerobot/aloha_static_candy",
|
"lerobot/aloha_static_candy",
|
||||||
"lerobot/aloha_static_coffee",
|
"lerobot/aloha_static_coffee",
|
||||||
|
@ -126,17 +129,19 @@ available_datasets = list(
|
||||||
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
|
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# refers to attribute "name" of policy instance
|
||||||
available_policies = [
|
available_policies = [
|
||||||
"act",
|
"act",
|
||||||
"diffusion",
|
"diffusion",
|
||||||
"tdmpc",
|
"tdmpc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# keys and values refer to yaml files
|
||||||
available_policies_per_env = {
|
available_policies_per_env = {
|
||||||
"aloha": ["act"],
|
"aloha": ["act"],
|
||||||
"dora": ["act"],
|
|
||||||
"pusht": ["diffusion"],
|
"pusht": ["diffusion"],
|
||||||
"xarm": ["tdmpc"],
|
"xarm": ["tdmpc"],
|
||||||
|
"dora_aloha_real": ["act_real"],
|
||||||
}
|
}
|
||||||
|
|
||||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
# ```bash
|
# ```bash
|
||||||
# python lerobot/scripts/train.py \
|
# python lerobot/scripts/train.py \
|
||||||
# policy=act_real \
|
# policy=act_real \
|
||||||
# env=aloha_real
|
# env=dora_aloha_real
|
||||||
# ```
|
# ```
|
||||||
|
|
||||||
seed: 1000
|
seed: 1000
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
# ```bash
|
# ```bash
|
||||||
# python lerobot/scripts/train.py \
|
# python lerobot/scripts/train.py \
|
||||||
# policy=act_real_no_state \
|
# policy=act_real_no_state \
|
||||||
# env=aloha_real
|
# env=dora_aloha_real
|
||||||
# ```
|
# ```
|
||||||
|
|
||||||
seed: 1000
|
seed: 1000
|
||||||
|
|
Loading…
Reference in New Issue