2024-05-15 18:13:09 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2024-05-01 23:17:18 +08:00
|
|
|
import inspect
|
2024-05-04 22:20:30 +08:00
|
|
|
from pathlib import Path
|
2024-05-01 23:17:18 +08:00
|
|
|
|
2024-02-25 18:50:23 +08:00
|
|
|
import pytest
|
2024-03-14 23:22:55 +08:00
|
|
|
import torch
|
2024-05-01 23:17:18 +08:00
|
|
|
from huggingface_hub import PyTorchModelHubMixin
|
2024-05-04 22:20:30 +08:00
|
|
|
from safetensors.torch import load_file
|
2024-02-25 18:50:23 +08:00
|
|
|
|
2024-05-01 23:17:18 +08:00
|
|
|
from lerobot import available_policies
|
2024-04-18 20:47:42 +08:00
|
|
|
from lerobot.common.datasets.factory import make_dataset
|
2024-04-08 00:01:22 +08:00
|
|
|
from lerobot.common.datasets.utils import cycle
|
2024-04-18 20:47:42 +08:00
|
|
|
from lerobot.common.envs.factory import make_env
|
2024-05-04 00:33:16 +08:00
|
|
|
from lerobot.common.envs.utils import preprocess_observation
|
2024-05-01 23:17:18 +08:00
|
|
|
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
|
2024-04-25 17:47:38 +08:00
|
|
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
2024-04-16 23:35:04 +08:00
|
|
|
from lerobot.common.policies.policy_protocol import Policy
|
2024-04-18 20:47:42 +08:00
|
|
|
from lerobot.common.utils.utils import init_hydra_config
|
2024-05-31 21:31:02 +08:00
|
|
|
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
2024-05-20 19:48:09 +08:00
|
|
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
2024-02-25 18:50:23 +08:00
|
|
|
|
2024-04-16 23:31:44 +08:00
|
|
|
|
2024-05-01 23:17:18 +08:00
|
|
|
@pytest.mark.parametrize("policy_name", available_policies)
def test_get_policy_and_config_classes(policy_name: str):
    """Verify that `get_policy_and_config_classes` pairs a policy name with matching classes.

    The returned policy class must report the requested name, and the returned config class must be a
    subclass of the type annotated on the policy constructor's `config` parameter.
    """
    policy_cls, config_cls = get_policy_and_config_classes(policy_name)
    assert policy_cls.name == policy_name

    # Look up the annotation the policy's __init__ declares for its `config` argument.
    init_params = inspect.signature(policy_cls.__init__).parameters
    expected_config_type = init_params["config"].annotation
    assert issubclass(config_cls, expected_config_type)
|
|
|
|
|
|
|
|
|
2024-04-18 20:47:42 +08:00
|
|
|
def _rename_feature_key(cfg, config_paths, old_key: str, new_key: str):
    """Rename `old_key` to `new_key` in each `cfg[section][field]` mapping listed in `config_paths`.

    Each mapping is copied into a plain dict, the entry is moved to the new key (appended last, with the
    old key removed), and the dict is written back into the config.
    """
    for section, field in config_paths:
        dct = dict(cfg[section][field])
        dct[new_key] = dct.pop(old_key)
        cfg[section][field] = dct


# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.parametrize(
    "env_name,policy_name,extra_overrides",
    [
        ("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]),
        ("pusht", "diffusion", []),
        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
        (
            "aloha",
            "act",
            ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_scripted"],
        ),
        (
            "aloha",
            "act",
            ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_human"],
        ),
        (
            "aloha",
            "act",
            ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
        ),
        # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
        (
            "aloha",
            "diffusion",
            ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
        ),
        # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
        ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
        ("dora_aloha_real", "act_real", []),
        ("dora_aloha_real", "act_real_no_state", []),
    ],
)
@require_env
def test_policy(env_name, policy_name, extra_overrides):
    """
    Tests:
        - Making the policy object.
        - Checking that the policy follows the correct protocol and subclasses nn.Module
            and PyTorchModelHubMixin.
        - Running a training forward pass on a dataset batch.
        - Using the policy to select actions at inference time.
        - Stepping the environment with the selected action.

    Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
          and for now we add tests as we see fit.
    """
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            f"device={DEVICE}",
        ]
        + extra_overrides,
    )

    # Additional config override logic: these env/policy pairings use a different image feature key than
    # the one in the policy's default config, so the key is renamed in the relevant config mappings.
    # NOTE(review): override_dataset_stats is cleared as well, presumably because it is keyed by the
    # original feature name — confirm against the policy configs.
    if env_name == "aloha" and policy_name == "diffusion":
        _rename_feature_key(
            cfg,
            [
                ("training", "delta_timestamps"),
                ("policy", "input_shapes"),
                ("policy", "input_normalization_modes"),
            ],
            old_key="observation.image",
            new_key="observation.images.top",
        )
        cfg.override_dataset_stats = None

    if env_name == "pusht" and policy_name == "act":
        _rename_feature_key(
            cfg,
            [
                ("policy", "input_shapes"),
                ("policy", "input_normalization_modes"),
            ],
            old_key="observation.images.top",
            new_key="observation.image",
        )
        cfg.override_dataset_stats = None

    # Check that we can make the policy object.
    dataset = make_dataset(cfg)
    policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
    # Check that the policy follows the required protocol.
    assert isinstance(
        policy, Policy
    ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
    assert isinstance(policy, torch.nn.Module)
    assert isinstance(policy, PyTorchModelHubMixin)

    # Check that we run select_actions and get the appropriate output.
    env = make_env(cfg, n_envs=2)
    try:
        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=0,
            batch_size=2,
            shuffle=True,
            pin_memory=DEVICE != "cpu",
            drop_last=True,
        )
        dl_iter = cycle(dataloader)

        batch = next(dl_iter)

        for key in batch:
            batch[key] = batch[key].to(DEVICE, non_blocking=True)

        # Run a training forward pass (no backward pass or optimizer step is taken here).
        policy.forward(batch)

        # reset the policy and environment
        policy.reset()
        observation, _ = env.reset(seed=cfg.seed)

        # convert the raw environment observation into the policy's input format
        observation = preprocess_observation(observation)

        # send observation to device/gpu
        observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}

        # get the next action for the environment
        with torch.inference_mode():
            action = policy.select_action(observation).cpu().numpy()

        # Test that the selected action can be applied to the environment.
        env.step(action)
    finally:
        # Always release the environment's resources, even if an assertion or step above fails.
        env.close()
|
2024-04-25 17:47:38 +08:00
|
|
|
|
2024-04-29 15:26:59 +08:00
|
|
|
|
2024-05-01 23:17:18 +08:00
|
|
|
@pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(policy_name: str):
    """Construct each available policy with no arguments to confirm its default config is usable."""
    policy_cls, _config_cls = get_policy_and_config_classes(policy_name)
    # The test passes as long as construction does not raise.
    policy_cls()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(policy_name: str, tmp_path: Path):
    """Check that a default-constructed policy survives a save_pretrained / from_pretrained round trip.

    The policy is saved into a per-policy directory under pytest's `tmp_path`, reloaded with
    `from_pretrained`, and every parameter of the reloaded policy must be exactly equal to the original's.
    """
    policy_cls, _ = get_policy_and_config_classes(policy_name)
    policy: Policy = policy_cls()
    # Use pytest's per-test temporary directory (with a real f-string) so that different policies and
    # concurrent runs never share a save location, and nothing is left behind in /tmp.
    save_dir = str(tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}")
    policy.save_pretrained(save_dir)
    policy_ = policy_cls.from_pretrained(save_dir)
    # strict=True makes a mismatch in the number of parameters fail loudly instead of being truncated.
    assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
def test_normalize(insert_temporal_dim):
    """
    Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
    an exception when the forward pass is called without the stats having been provided.

    Also checks that a module built with `stats=None` can run a forward pass after loading the state dict
    of a module that was built with stats.

    TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
    expected.
    """
    input_shapes = {
        "observation.image": [3, 96, 96],
        "observation.state": [10],
    }
    output_shapes = {
        "action": [5],
    }

    normalize_input_modes = {
        "observation.image": "mean_std",
        "observation.state": "min_max",
    }
    unnormalize_output_modes = {
        "action": "min_max",
    }

    # Random stats are sufficient here: the test only checks that the modules run, not their outputs.
    dataset_stats = {
        "observation.image": {
            "mean": torch.randn(3, 1, 1),
            "std": torch.randn(3, 1, 1),
            "min": torch.randn(3, 1, 1),
            "max": torch.randn(3, 1, 1),
        },
        "observation.state": {
            "mean": torch.randn(10),
            "std": torch.randn(10),
            "min": torch.randn(10),
            "max": torch.randn(10),
        },
        "action": {
            "mean": torch.randn(5),
            "std": torch.randn(5),
            "min": torch.randn(5),
            "max": torch.randn(5),
        },
    }

    bsize = 2
    input_batch = {
        "observation.image": torch.randn(bsize, 3, 96, 96),
        "observation.state": torch.randn(bsize, 10),
    }
    output_batch = {
        "action": torch.randn(bsize, 5),
    }

    if insert_temporal_dim:
        tdim = 4

        for key in input_batch:
            # [2,3,96,96] -> [2,tdim,3,96,96]
            input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)

        for key in output_batch:
            output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)

    # test without stats
    normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
    with pytest.raises(AssertionError):
        normalize(input_batch)

    # test with stats
    normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
    normalize(input_batch)

    # test loading pretrained models
    new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
    new_normalize.load_state_dict(normalize.state_dict())
    new_normalize(input_batch)

    # test without stats
    unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
    with pytest.raises(AssertionError):
        unnormalize(output_batch)

    # test with stats
    unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
    unnormalize(output_batch)

    # test loading pretrained models
    new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
    new_unnormalize.load_state_dict(unnormalize.state_dict())
    # Run the *reloaded* module; calling `unnormalize` here would leave the state-dict path untested.
    new_unnormalize(output_batch)
|
2024-05-04 22:20:30 +08:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env_name, policy_name, extra_overrides",
    [
        ("xarm", "tdmpc", []),
        (
            "pusht",
            "diffusion",
            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
        ),
        ("aloha", "act", ["policy.n_action_steps=10"]),
        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
    ],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(env_name, policy_name, extra_overrides):
    """Compare freshly computed policy statistics against the artifacts stored under `tests/data`.

    NOTE: If this test does not pass, and you have intentionally changed something in the policy:
        1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
           include a report on what changed and how that affected the outputs.
        2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
           add the policies you want to update the test artifacts for.
        3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
        4. Check that this test now passes.
        5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
        6. Remember to stage and commit the resulting changes to `tests/data`.
    """
    artifacts_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
    saved_output_dict = load_file(artifacts_dir / "output_dict.safetensors")
    saved_grad_stats = load_file(artifacts_dir / "grad_stats.safetensors")
    saved_param_stats = load_file(artifacts_dir / "param_stats.safetensors")
    saved_actions = load_file(artifacts_dir / "actions.safetensors")

    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)

    # (freshly computed, stored reference, relative tolerance). Parameter statistics use a much looser
    # relative tolerance (50) than the other groups (0.1); all groups share atol=1e-7.
    comparisons = [
        (output_dict, saved_output_dict, 0.1),
        (grad_stats, saved_grad_stats, 0.1),
        (param_stats, saved_param_stats, 50),
        (actions, saved_actions, 0.1),
    ]
    for current, reference, rtol in comparisons:
        # Only keys present in the stored artifact are compared.
        for key, reference_value in reference.items():
            assert torch.isclose(current[key], reference_value, rtol=rtol, atol=1e-7).all()
|