#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from copy import deepcopy
from pathlib import Path

import einops
import pytest
import torch
from safetensors.torch import load_file

from lerobot import available_policies
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle, dataset_to_policy_features
from lerobot.common.envs.factory import make_env, make_env_config
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.common.policies.factory import (
    get_policy_class,
    make_policy,
    make_policy_config,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel


@pytest.fixture
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
    """Build dataset metadata with one square camera and 6-dimensional state/action features.

    The metadata is created under `tmp_path / "init"` with a single episode containing a single frame.
    """
    # Create only one camera input which is squared to fit all current policy constraints
    # e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared
    camera_features = {
        "observation.images.laptop": {
            "shape": (84, 84, 3),
            "names": ["height", "width", "channels"],
            "info": None,
        },
    }
    motor_features = {
        "action": {
            "dtype": "float32",
            "shape": (6,),
            "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
        },
        "observation.state": {
            "dtype": "float32",
            "shape": (6,),
            "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
        },
    }
    info = info_factory(
        total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
    )
    ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
    return ds_meta


@pytest.mark.parametrize("policy_name", available_policies)
def test_get_policy_and_config_classes(policy_name: str):
    """Check that the correct policy and config classes are returned."""
    policy_cls = get_policy_class(policy_name)
    policy_cfg = make_policy_config(policy_name)
    assert policy_cls.name == policy_name
    # The config instance must match the type annotated on the policy's `config` constructor parameter.
    assert issubclass(
        policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
    )


@pytest.mark.parametrize(
    "ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs",
    [
        ("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}),
        ("lerobot/pusht", "pusht", {}, "diffusion", {}),
        ("lerobot/pusht", "pusht", {}, "vqbet", {}),
        ("lerobot/pusht", "pusht", {}, "act", {}),
        ("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
        (
            "lerobot/aloha_sim_insertion_scripted",
            "aloha",
            {"task": "AlohaInsertion-v0"},
            "act",
            {},
        ),
        (
            "lerobot/aloha_sim_insertion_human",
            "aloha",
            {"task": "AlohaInsertion-v0"},
            "diffusion",
            {},
        ),
        (
            "lerobot/aloha_sim_transfer_cube_human",
            "aloha",
            {"task": "AlohaTransferCube-v0"},
            "act",
            {},
        ),
        (
            "lerobot/aloha_sim_transfer_cube_scripted",
            "aloha",
            {"task": "AlohaTransferCube-v0"},
            "act",
            {},
        ),
    ],
)
@require_env
def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
    """
    Tests:
        - Making the policy object.
        - Checking that the policy is a `PreTrainedPolicy`.
        - Running a forward pass on a training batch without mutating the batch.
        - Using the policy to select actions at inference time without mutating the observation.
        - Test the action can be applied to the environment.

    Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
          and for now we add tests as we see fit.
    """

    train_cfg = TrainPipelineConfig(
        # TODO(rcadene, aliberts): remove dataset download
        dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
        policy=make_policy_config(policy_name, **policy_kwargs),
        env=make_env_config(env_name, **env_kwargs),
        device=DEVICE,
    )

    # Check that we can make the policy object.
    dataset = make_dataset(train_cfg)
    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=DEVICE)
    assert isinstance(policy, PreTrainedPolicy)

    # Check that we run select_actions and get the appropriate output.
    env = make_env(train_cfg.env, n_envs=2)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=2,
        shuffle=True,
        pin_memory=DEVICE != "cpu",
        drop_last=True,
    )
    dl_iter = cycle(dataloader)

    batch = next(dl_iter)

    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(DEVICE, non_blocking=True)

    # Test updating the policy (and test that it does not mutate the batch)
    batch_ = deepcopy(batch)
    policy.forward(batch)
    assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
    assert all(
        torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
        for k in batch
    ), "Batch values are not the same after a forward pass."

    # reset the policy and environment
    policy.reset()
    observation, _ = env.reset(seed=train_cfg.seed)

    # apply transform to normalize the observations
    observation = preprocess_observation(observation)

    # send observation to device/gpu
    observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}

    # get the next action for the environment (also check that the observation batch is not modified)
    observation_ = deepcopy(observation)
    with torch.inference_mode():
        action = policy.select_action(observation).cpu().numpy()
    assert set(observation) == set(observation_), (
        "Observation batch keys are not the same after a forward pass."
    )
    assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
        "Observation batch values are not the same after a forward pass."
    )

    # Test step through policy
    env.step(action)


# TODO(rcadene, aliberts): This test is quite end-to-end. Move this test in test_optimizer?
def test_act_backbone_lr():
    """
    Test that the ACT policy can be instantiated with a different learning rate for the backbone.
    """

    cfg = TrainPipelineConfig(
        # TODO(rcadene, aliberts): remove dataset download
        dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
        policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
        device=DEVICE,
    )
    cfg.validate()  # Needed for auto-setting some parameters

    assert cfg.policy.optimizer_lr == 0.01
    assert cfg.policy.optimizer_lr_backbone == 0.001

    dataset = make_dataset(cfg)
    policy = make_policy(cfg.policy, device=DEVICE, ds_meta=dataset.meta)
    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
    # Two parameter groups are expected: the first uses `optimizer_lr`, the second `optimizer_lr_backbone`.
    assert len(optimizer.param_groups) == 2
    assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr
    assert optimizer.param_groups[1]["lr"] == cfg.policy.optimizer_lr_backbone
    assert len(optimizer.param_groups[0]["params"]) == 133
    assert len(optimizer.param_groups[1]["params"]) == 20


@pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
    """Check that the policy can be instantiated with defaults."""
    policy_cls = get_policy_class(policy_name)
    policy_cfg = make_policy_config(policy_name)
    features = dataset_to_policy_features(dummy_dataset_metadata.features)
    # Action features are the outputs; every other feature is treated as an input.
    policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    policy_cfg.input_features = {
        key: ft for key, ft in features.items() if key not in policy_cfg.output_features
    }
    policy_cls(policy_cfg)


@pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
    """Check that a policy saved with `save_pretrained` reloads with identical parameters."""
    policy_cls = get_policy_class(policy_name)
    policy_cfg = make_policy_config(policy_name)
    features = dataset_to_policy_features(dummy_dataset_metadata.features)
    # Action features are the outputs; every other feature is treated as an input.
    policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    policy_cfg.input_features = {
        key: ft for key, ft in features.items() if key not in policy_cfg.output_features
    }
    policy = policy_cls(policy_cfg)
    save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
    policy.save_pretrained(save_dir)
    policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
    # `strict=True` makes the check fail if the two policies do not have the same number of parameters.
    assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))


@pytest.mark.parametrize("insert_temporal_dim", [False, True])
def test_normalize(insert_temporal_dim):
    """
    Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
    an exception when the forward pass is called without the stats having been provided.

    TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
    expected.
    """

    input_features = {
        "observation.image": PolicyFeature(
            type=FeatureType.VISUAL,
            shape=(3, 96, 96),
        ),
        "observation.state": PolicyFeature(
            type=FeatureType.STATE,
            shape=(10,),
        ),
    }
    output_features = {
        "action": PolicyFeature(
            type=FeatureType.ACTION,
            shape=(5,),
        ),
    }

    norm_map = {
        "VISUAL": NormalizationMode.MEAN_STD,
        "STATE": NormalizationMode.MIN_MAX,
        "ACTION": NormalizationMode.MIN_MAX,
    }

    dataset_stats = {
        "observation.image": {
            "mean": torch.randn(3, 1, 1),
            "std": torch.randn(3, 1, 1),
            "min": torch.randn(3, 1, 1),
            "max": torch.randn(3, 1, 1),
        },
        "observation.state": {
            "mean": torch.randn(10),
            "std": torch.randn(10),
            "min": torch.randn(10),
            "max": torch.randn(10),
        },
        "action": {
            "mean": torch.randn(5),
            "std": torch.randn(5),
            "min": torch.randn(5),
            "max": torch.randn(5),
        },
    }

    bsize = 2
    input_batch = {
        "observation.image": torch.randn(bsize, 3, 96, 96),
        "observation.state": torch.randn(bsize, 10),
    }
    output_batch = {
        "action": torch.randn(bsize, 5),
    }

    if insert_temporal_dim:
        tdim = 4

        for key in input_batch:
            # [2,3,96,96] -> [2,tdim,3,96,96]
            input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)

        for key in output_batch:
            output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)

    # test without stats
    normalize = Normalize(input_features, norm_map, stats=None)
    with pytest.raises(AssertionError):
        normalize(input_batch)

    # test with stats
    normalize = Normalize(input_features, norm_map, stats=dataset_stats)
    normalize(input_batch)

    # test loading pretrained models
    new_normalize = Normalize(input_features, norm_map, stats=None)
    new_normalize.load_state_dict(normalize.state_dict())
    new_normalize(input_batch)

    # test without stats
    unnormalize = Unnormalize(output_features, norm_map, stats=None)
    with pytest.raises(AssertionError):
        unnormalize(output_batch)

    # test with stats
    unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
    unnormalize(output_batch)

    # test loading pretrained models
    new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
    new_unnormalize.load_state_dict(unnormalize.state_dict())
    # Run the freshly loaded module (not the original one) so that the loaded state is what gets exercised,
    # mirroring the `new_normalize` check above.
    new_unnormalize(output_batch)


@pytest.mark.parametrize(
    "ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra",
    [
        # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
        # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
        # to test with `policy.use_mpc=false`.
        ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, {"batch_size": 25}, "use_policy"),
        # ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, {}, "use_mpc"),
        # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed
        # to normalize the image at all. In our current codebase we don't normalize at all. But there is still a
        # minor difference that fails the test. However, by testing to normalize the image with 0.5 0.5 in the
        # current codebase, the test pass. Thus, we deactivate this test for now.
        # (
        #     "lerobot/pusht",
        #     "pusht",
        #     "diffusion",
        #     {
        #         "n_action_steps": 8,
        #         "num_inference_steps": 10,
        #         "down_dims": [128, 256, 512],
        #     },
        #     {"batch_size": 64},
        #     "",
        # ),
        ("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, {}, ""),
        (
            "lerobot/aloha_sim_insertion_human",
            "aloha",
            "act",
            {"n_action_steps": 1000, "chunk_size": 1000},
            {},
            "_1000_steps",
        ),
    ],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(
    ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra
):
    """
    NOTE: If this test does not pass, and you have intentionally changed something in the policy:
        1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
           include a report on what changed and how that affected the outputs.
        2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
           add the policies you want to update the test artifacts for.
        3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
        4. Check that this test now passes.
        5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
        6. Remember to stage and commit the resulting changes to `tests/data`.
    """
    env_policy_dir = (
        Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
    )
    saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
    saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
    saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
    saved_actions = load_file(env_policy_dir / "actions.safetensors")

    output_dict, grad_stats, param_stats, actions = get_policy_stats(
        ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs
    )

    # Every key stored in the saved artifacts must be reproduced within tolerance by the freshly computed stats.
    for key in saved_output_dict:
        assert torch.allclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7)
    for key in saved_grad_stats:
        assert torch.allclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7)
    for key in saved_param_stats:
        assert torch.allclose(param_stats[key], saved_param_stats[key], rtol=0.1, atol=1e-7)
    for key in saved_actions:
        assert torch.allclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7)


def test_act_temporal_ensembler():
    """Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
    temporal_ensemble_coeff = 0.01
    chunk_size = 100
    episode_length = 101
    ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
    # A batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
    # "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
    with seeded_context(0):
        # Dimension is (batch, episode_length, chunk_size, action_dim(=1))
        # Stepping through the episode_length dim is like running inference at each rollout step and getting
        # a different action chunk.
        batch_seq = torch.stack(
            [
                torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
                torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
                torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
            ],
            dim=0,
        ).unsqueeze(-1)  # unsqueeze for action dim
    batch_size = batch_seq.shape[0]
    # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
    # dimension of `batch_seq`.
    weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)

    # Simulate stepping through a rollout and computing a batch of actions with model on each step.
    for i in range(episode_length):
        # Mock a batch of actions.
        actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
        online_avg = ensembler.update(actions)
        # Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
        # Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
        # What we want to do is take diagonal slices across it starting from the left.
        # eg: chunk_size=4, episode_length=6
        # ┌───────┐
        # │0 1 2 3│
        # │1 2 3 4│
        # │2 3 4 5│
        # │3 4 5 6│
        # │4 5 6 7│
        # │5 6 7 8│
        # └───────┘
        chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
        episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
        seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
        offline_avg = (
            einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
        )
        # Sanity check. The average should be between the extrema.
        assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
        assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
        # Selected atol=1e-4 keeping in mind actions in [-1, 1] and expecting 0.01% error.
        assert torch.allclose(online_avg, offline_avg, atol=1e-4)