import pytest
import torch

from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils import init_hydra_config

from .utils import DEVICE, DEFAULT_CONFIG_PATH


@pytest.mark.parametrize(
    "env_name,policy_name,extra_overrides",
    [
        ("simxarm", "tdmpc", ["policy.mpc=true"]),
        ("pusht", "tdmpc", ["policy.mpc=false"]),
        ("pusht", "diffusion", []),
        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
        # ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
        # ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
        # TODO(aliberts): simxarm not working with diffusion
        # ("simxarm", "diffusion", []),
    ],
)
def test_policy(env_name, policy_name, extra_overrides):
    """
    Tests:
        - Making the policy object.
        - Updating the policy.
        - Using the policy to select actions at inference time.
        - Applying the selected action in the environment.
    """
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            f"device={DEVICE}",
        ]
        + extra_overrides,
    )
    # Check that we can make the policy object.
    policy = make_policy(cfg)
    # Check that we can run select_action and get an appropriately shaped output.
    dataset = make_dataset(cfg)
    env = make_env(cfg, num_parallel_envs=2)

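    # Build a dataloader over the offline dataset; cycle() lets us keep drawing batches indefinitely.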
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=cfg.policy.batch_size,
        shuffle=True,
        pin_memory=DEVICE != "cpu",
        drop_last=True,
    )
    dl_iter = cycle(dataloader)

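    # Sample one training batch and move every tensor to the target device.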
    batch = next(dl_iter)

    for key in batch:
        batch[key] = batch[key].to(DEVICE, non_blocking=True)

    # Test updating the policy
    policy(batch, step=0)

    # reset the policy and environment
    policy.reset()
    observation, _ = env.reset(seed=cfg.seed)

    # apply transform to normalize the observations
    observation = preprocess_observation(observation, dataset.transform)

    # send observation to device/gpu
    observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}

    # get the next action for the environment
    with torch.inference_mode():
        action = policy.select_action(observation, step=0)

    # apply inverse transform to unnormalize the action
    action = postprocess_action(action, dataset.transform)

    # Test stepping through the environment with the selected action
    env.step(action)