import pytest
import torch

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import init_hydra_config

from .utils import DEFAULT_CONFIG_PATH, DEVICE


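# Each case below is (env_name, policy_name, extra_overrides); the overrides are
# Hydra-style "key=value" strings appended to the base config when it is loaded.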
@pytest.mark.parametrize(
    "env_name,policy_name,extra_overrides",
    [
        ("xarm", "tdmpc", ["policy.mpc=true"]),
        ("pusht", "tdmpc", ["policy.mpc=false"]),
        ("pusht", "diffusion", []),
        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
        # TODO(aliberts): xarm not working with diffusion
        # ("xarm", "diffusion", []),
    ],
)
def test_policy(env_name, policy_name, extra_overrides):
    """
    Tests:
    - Making the policy object.
    - Updating the policy.
    - Using the policy to select actions at inference time.
    - Applying the selected action to the environment.
    """
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            f"device={DEVICE}",
        ]
        + extra_overrides,
    )
    # Check that we can make the policy object.
    policy = make_policy(cfg)
    # Check that we can run select_action and get the appropriate output.
    dataset = make_dataset(cfg)
    env = make_env(cfg, num_parallel_envs=2)
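    # `num_parallel_envs=2` suggests a batched/vectorized env: reset() and step()
    # operate on a batch of 2, and the policy receives batched observations below.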

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=2,
        shuffle=True,
        pin_memory=DEVICE != "cpu",
        drop_last=True,
    )
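    # Pinned memory only speeds up host-to-device copies, so it is enabled only off-CPU.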
    dl_iter = cycle(dataloader)
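    # `cycle` wraps the dataloader in an infinite iterator, so `next` always yields a batch.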

    batch = next(dl_iter)

    for key in batch:
        batch[key] = batch[key].to(DEVICE, non_blocking=True)
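    # With pinned memory above, non_blocking=True makes these copies asynchronous.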

    # Test updating the policy.
    policy(batch, step=0)
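    # (`step` is presumably used for step-dependent behavior such as schedules or
    # logging; 0 stands in for the training step in this test.)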

    # Reset the policy and environment.
    policy.reset()
    observation, _ = env.reset(seed=cfg.seed)

    # Apply transform to normalize the observations.
    observation = preprocess_observation(observation, dataset.transform)

    # Send observation to device/GPU.
    observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}

    # Get the next action for the environment.
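    # torch.inference_mode() disables autograd tracking; only a forward pass is needed here.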
    with torch.inference_mode():
        action = policy.select_action(observation, step=0)

    # Apply inverse transform to unnormalize the action.
    action = postprocess_action(action, dataset.transform)

    # Test stepping through the environment with the selected action.
    env.step(action)
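

# A single parametrized case can be selected with pytest's -k filter, e.g.
# (the file name below is assumed):
#   pytest tests/test_policies.py -k "pusht and diffusion"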