From 68c1b13406068b9d88afbfcb2366f927141514f3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 16 May 2024 13:51:53 +0100 Subject: [PATCH] Make policies compatible with other/multiple image keys (#149) --- .../common/policies/act/configuration_act.py | 7 --- lerobot/common/policies/act/modeling_act.py | 46 +++++++------------ .../diffusion/configuration_diffusion.py | 17 +++++-- .../policies/diffusion/modeling_diffusion.py | 34 +++++++++----- .../policies/tdmpc/configuration_tdmpc.py | 12 +++-- .../common/policies/tdmpc/modeling_tdmpc.py | 20 ++++---- lerobot/common/policies/utils.py | 4 ++ lerobot/scripts/train.py | 3 +- tests/test_policies.py | 33 +++++++++++++ 9 files changed, 107 insertions(+), 69 deletions(-) diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 95f443da..be444b06 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -145,10 +145,3 @@ class ACTConfig: raise ValueError( f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" ) - # Check that there is only one image. - # TODO(alexander-soare): generalize this to multiple images. - if ( - sum(k.startswith("observation.images.") for k in self.input_shapes) != 1 - or "observation.images.top" not in self.input_shapes - ): - raise ValueError('For now, only "observation.images.top" is accepted for an image input.') diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index e85a3736..4a8df1ce 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -62,6 +62,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): if config is None: config = ACTConfig() self.config = config + self.normalize_inputs = Normalize( config.input_shapes, config.input_normalization_modes, dataset_stats ) @@ -71,8 +72,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) + self.model = ACT(config) + self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + + self.reset() + def reset(self): """This should be called whenever the environment is reset.""" if self.config.n_action_steps is not None: @@ -86,13 +92,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. 
""" - assert "observation.images.top" in batch - assert "observation.state" in batch - self.eval() batch = self.normalize_inputs(batch) - self._stack_images(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self._action_queue) == 0: # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue @@ -108,8 +111,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) - self._stack_images(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( @@ -132,21 +135,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): return loss_dict - def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """Stacks all the images in a batch and puts them in a new key: "observation.images". - - This function expects `batch` to have (at least): - { - "observation.state": (B, state_dim) batch of robot states. - "observation.images.{name}": (B, C, H, W) tensor of images. - } - """ - # Stack images in the order dictated by input_shapes. - batch["observation.images"] = torch.stack( - [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")], - dim=-4, - ) - class ACT(nn.Module): """Action Chunking Transformer: The underlying neural network for ACTPolicy. @@ -176,10 +164,10 @@ class ACT(nn.Module): │ encoder │ │ │ │Transf.│ │ │ │ │ │ │encoder│ │ └───▲─────┘ │ │ │ │ │ - │ │ │ └───▲───┘ │ - │ │ │ │ │ - inputs └─────┼─────┘ │ - │ │ + │ │ │ └▲──▲─▲─┘ │ + │ │ │ │ │ │ │ + inputs └─────┼──┘ │ image emb. │ + │ state emb. │ └───────────────────────┘ """ @@ -321,18 +309,18 @@ class ACT(nn.Module): all_cam_features.append(cam_features) all_cam_pos_embeds.append(cam_pos_embed) # Concatenate camera observation feature maps and positional embeddings along the width dimension. - encoder_in = torch.cat(all_cam_features, axis=3) - cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3) + encoder_in = torch.cat(all_cam_features, axis=-1) + cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1) # Get positional embeddings for robot state and latent. - robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) - latent_embed = self.encoder_latent_input_proj(latent_sample) + robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C) + latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C) # Stack encoder input and positional embeddings moving to (S, B, C). encoder_in = torch.cat( [ torch.stack([latent_embed, robot_state_embed], axis=0), - encoder_in.flatten(2).permute(2, 0, 1), + einops.rearrange(encoder_in, "b c h w -> (h w) b c"), ] ) pos_embed = torch.cat( diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index d0554942..632f6cd6 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -148,14 +148,21 @@ class DiffusionConfig: raise ValueError( f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) + # There should only be one image key. 
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
         if (
-            self.crop_shape[0] > self.input_shapes["observation.image"][1]
-            or self.crop_shape[1] > self.input_shapes["observation.image"][2]
+            self.crop_shape[0] > self.input_shapes[image_key][1]
+            or self.crop_shape[1] > self.input_shapes[image_key][2]
         ):
             raise ValueError(
-                f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
-                f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
-                '`input_shapes["observation.image"]`.'
+                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                f"`input_shapes[{image_key}]`."
             )
         supported_prediction_types = ["epsilon", "sample"]
         if self.prediction_type not in supported_prediction_types:
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index c67040b6..1659b68e 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -19,6 +19,7 @@
 TODO(alexander-soare):
   - Remove reliance on Robomimic for SpatialSoftmax.
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
+  - Make compatible with multiple image keys.
 """

 import math
@@ -83,10 +84,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):

         self.diffusion = DiffusionModel(config)

+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: this check is covered in the post-init of the config, but we keep a sanity check just in case.
+        if len(image_keys) != 1:
+            raise NotImplementedError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        self.input_image_key = image_keys[0]
+
+        self.reset()
+
     def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
+        """Clear observation and action queues. Should be called on `env.reset()`."""
         self._queues = {
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
@@ -115,16 +124,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         "horizon" may not be the best name to describe what the variable actually means, because this period
         is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]

         self._queues = populate_queues(self._queues, batch)

         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
             actions = self.diffusion.generate_actions(batch)

             # TODO(rcadene): make above methods return output dictionary?
@@ -138,6 +145,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) + batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} @@ -215,13 +223,12 @@ class DiffusionModel(nn.Module): def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: """ - This function expects `batch` to have (at least): + This function expects `batch` to have: { "observation.state": (B, n_obs_steps, state_dim) "observation.image": (B, n_obs_steps, C, H, W) } """ - assert set(batch).issuperset({"observation.state", "observation.image"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] assert n_obs_steps == self.config.n_obs_steps @@ -345,9 +352,12 @@ class DiffusionRgbEncoder(nn.Module): # Set up pooling and final layers. # Use a dry run to get the feature map shape. - # The dummy input should take the number of image channels from `config.input_shapes` and it should use the - # height and width from `config.crop_shape`. - dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape)) + # The dummy input should take the number of image channels from `config.input_shapes` and it should + # use the height and width from `config.crop_shape`. + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + assert len(image_keys) == 1 + image_key = image_keys[0] + dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape)) with torch.inference_mode(): dummy_feature_map = self.backbone(dummy_input) feature_map_shape = tuple(dummy_feature_map.shape[1:]) diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index ddf52248..cf76fb08 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -147,12 +147,18 @@ class TDMPCConfig: def __post_init__(self): """Input validation (not exhaustive).""" - if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]: + # There should only be one image key. + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} + if len(image_keys) != 1: + raise ValueError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) + image_key = next(iter(image_keys)) + if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]: # TODO(alexander-soare): This limitation is solely because of code in the random shift # augmentation. It should be able to be removed. raise ValueError( - "Only square images are handled now. Got image shape " - f"{self.input_shapes['observation.image']}." + f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}." 
             )
         if self.n_gaussian_samples <= 0:
             raise ValueError(
diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index 70e78c98..7c873bf2 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -112,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )

-    def save(self, fp):
-        """Save state dict of TOLD model to filepath."""
-        torch.save(self.state_dict(), fp)
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: this check is covered in the post-init of the config, but we keep a sanity check just in case.
+        assert len(image_keys) == 1
+        self.input_image_key = image_keys[0]

-    def load(self, fp):
-        """Load a saved state dict from filepath into current agent."""
-        self.load_state_dict(torch.load(fp))
+        self.reset()

     def reset(self):
         """
@@ -137,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]):
         """Select a single action given environment observations."""
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]

         self._queues = populate_queues(self._queues, batch)
@@ -319,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         device = get_device_from_parameters(self)

         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)

         info = {}

-        # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
-        batch_size = batch["index"].shape[0]
-
         # (b, t) -> (t, b)
         for key in batch:
             if batch[key].ndim > 1:
@@ -353,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
         # gives us a next `z`.
+        batch_size = batch["index"].shape[0]
         z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
         z_preds[0] = self.model.encode(current_observation)
         reward_preds = torch.empty_like(reward, device=device)
diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py
index 8f7b6eec..5a62daa2 100644
--- a/lerobot/common/policies/utils.py
+++ b/lerobot/common/policies/utils.py
@@ -19,6 +19,10 @@ from torch import nn

 def populate_queues(queues, batch):
     for key in batch:
+        # Ignore keys that are not already in the queues (the caller is responsible for making sure the
+        # queues contain the keys they need).
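+        # For example, DiffusionPolicy queues only "observation.image", "observation.state" and "action",
+        # while the incoming batch may carry additional keys (e.g. "index" or other camera streams).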
+ if key not in queues: + continue if len(queues[key]) != queues[key].maxlen: # initialize by copying the first observation several times until the queue is full while len(queues[key]) != queues[key].maxlen: diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index ab07695b..7ca7a0b3 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -23,6 +23,7 @@ import hydra import torch from datasets import concatenate_datasets from datasets.utils import disable_progress_bars, enable_progress_bars +from omegaconf import DictConfig from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle @@ -307,7 +308,7 @@ def add_episodes_inplace( sampler.num_samples = len(concat_dataset) -def train(cfg: dict, out_dir=None, job_name=None): +def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): if out_dir is None: raise NotImplementedError() if job_name is None: diff --git a/tests/test_policies.py b/tests/test_policies.py index c8457854..75633fe6 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -64,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str): "act", ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"], ), + # Note: these parameters also need custom logic in the test function for overriding the Hydra config. + ( + "aloha", + "diffusion", + ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"], + ), + # Note: these parameters also need custom logic in the test function for overriding the Hydra config. + ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]), ], ) @require_env @@ -87,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides): + extra_overrides, ) + # Additional config override logic. + if env_name == "aloha" and policy_name == "diffusion": + for keys in [ + ("training", "delta_timestamps"), + ("policy", "input_shapes"), + ("policy", "input_normalization_modes"), + ]: + dct = dict(cfg[keys[0]][keys[1]]) + dct["observation.images.top"] = dct["observation.image"] + del dct["observation.image"] + cfg[keys[0]][keys[1]] = dct + cfg.override_dataset_stats = None + + # Additional config override logic. + if env_name == "pusht" and policy_name == "act": + for keys in [ + ("policy", "input_shapes"), + ("policy", "input_normalization_modes"), + ]: + dct = dict(cfg[keys[0]][keys[1]]) + dct["observation.image"] = dct["observation.images.top"] + del dct["observation.images.top"] + cfg[keys[0]][keys[1]] = dct + cfg.override_dataset_stats = None + # Check that we can make the policy object. dataset = make_dataset(cfg) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
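
Note (not part of the patch): below is a minimal sketch of the multi-camera stacking that ACTPolicy now performs. The camera key names, shapes, and batch size are illustrative assumptions; the stacking line mirrors `batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)` from modeling_act.py above.

    # Illustrative sketch only; key names and shapes are assumptions, not fixtures from this repo.
    import torch

    input_shapes = {
        "observation.images.top": [3, 480, 640],    # hypothetical camera
        "observation.images.wrist": [3, 480, 640],  # hypothetical camera
        "observation.state": [14],
    }

    # As in ACTPolicy.__init__: collect every image key declared in the config. The prefix
    # "observation.image" matches both "observation.image" and "observation.images.*".
    expected_image_keys = [k for k in input_shapes if k.startswith("observation.image")]

    # A batch holds one (B, C, H, W) tensor per camera; B=2 here.
    batch = {k: torch.zeros(2, *input_shapes[k]) for k in expected_image_keys}

    # Stacking along dim=-4 inserts a camera dimension, giving (B, n_cameras, C, H, W).
    images = torch.stack([batch[k] for k in expected_image_keys], dim=-4)
    print(images.shape)  # torch.Size([2, 2, 3, 480, 640])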