# lerobot/tests/test_policies.py

import pytest
import torch
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH


@pytest.mark.parametrize(
    "env_name,policy_name,extra_overrides",
    [
        ("xarm", "tdmpc", ["policy.mpc=true"]),
        ("pusht", "tdmpc", ["policy.mpc=false"]),
        ("pusht", "diffusion", []),
        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
        # TODO(aliberts): xarm is not yet working with diffusion
        # ("xarm", "diffusion", []),
    ],
)
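# Each tuple above becomes one pytest case. A single combination can be run via
# pytest's `-k` substring matching on the generated test ids, e.g.
# `pytest -k "pusht and diffusion"` (the exact ids follow pytest's defaults).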
def test_policy(env_name, policy_name, extra_overrides):
    """
    Tests:
    - Making the policy object.
    - Updating the policy with a training batch.
    - Using the policy to select actions at inference time.
    - Applying the selected action to the environment.
    """
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            f"device={DEVICE}",
        ]
        + extra_overrides,
    )
    # Check that we can make the policy object.
    policy = make_policy(cfg)

    # Check that we can make the dataset and the environment used for rollouts.
    dataset = make_dataset(cfg)
    env = make_env(cfg, num_parallel_envs=2)
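    # The dataloader settings below look test-oriented: a tiny batch, shuffling
    # for coverage, and drop_last=True so every batch is a full batch of 2
    # (a reading of the chosen values, not something this file states).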
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=2,
        shuffle=True,
        pin_memory=DEVICE != "cpu",
        drop_last=True,
    )
    dl_iter = cycle(dataloader)

    batch = next(dl_iter)
    for key in batch:
        batch[key] = batch[key].to(DEVICE, non_blocking=True)
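    # The forward call below is assumed to double as one training update, with
    # the `step` argument presumably driving internal scheduling/logging; this
    # test only checks that the call executes without error.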
    # Test updating the policy.
    policy(batch, step=0)

    # Reset the policy and environment.
    policy.reset()
    observation, _ = env.reset(seed=cfg.seed)

    # Apply the transform to normalize the observations.
    observation = preprocess_observation(observation, dataset.transform)

    # Send the observations to the device (e.g. GPU).
    observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}

    # Get the next action for the environment.
    with torch.inference_mode():
        action = policy.select_action(observation, step=0)

    # Apply the inverse transform to unnormalize the action.
    action = postprocess_action(action, dataset.transform)

    # Test that the selected action can be applied to the environment.
    env.step(action)
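    # A possible follow-up (sketch only, not asserted here): with a
    # gymnasium-style vectorized env, `env.step` should return a 5-tuple and the
    # action batch should have one row per parallel env, e.g.:
    #     observation, reward, terminated, truncated, info = env.step(action)
    #     assert action.shape[0] == 2  # matches num_parallel_envs above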