From 23be5e1e7bdee2ed36a96f3d9b3511943d883b50 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 16 Apr 2024 16:31:44 +0100 Subject: [PATCH 1/7] backup wip --- examples/3_train_policy.py | 2 +- lerobot/common/policies/act/modeling_act.py | 23 +++++++++++---------- lerobot/scripts/train.py | 4 ++-- tests/test_policies.py | 12 ++++++++--- 4 files changed, 24 insertions(+), 17 deletions(-) diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index d2e8b8c9..d3467562 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -55,7 +55,7 @@ while not done: for batch in dataloader: for k in batch: batch[k] = batch[k].to(device, non_blocking=True) - info = policy(batch) + info = policy.update(batch) if step % log_freq == 0: num_samples = (step + 1) * cfg.batch_size loss = info["loss"] diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 18ea3377..567721cd 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -161,6 +161,9 @@ class ActionChunkingTransformerPolicy(nn.Module): if self.cfg.n_action_steps is not None: self._action_queue = deque([], maxlen=self.cfg.n_action_steps) + def forward(self, batch: dict[str, Tensor], **_) -> Tensor: + return self.select_action(self, batch) + @torch.no_grad def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: """Select a single action given environment observations. @@ -172,23 +175,17 @@ class ActionChunkingTransformerPolicy(nn.Module): if len(self._action_queue) == 0: # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has # shape (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(self.select_actions(batch).transpose(0, 1)) + self._action_queue.extend(self._select_actions(batch).transpose(0, 1)) return self._action_queue.popleft() @torch.no_grad - def select_actions(self, batch: dict[str, Tensor]) -> Tensor: + def _select_actions(self, batch: dict[str, Tensor]) -> Tensor: """Use the action chunking transformer to generate a sequence of actions.""" self.eval() self._preprocess_batch(batch, add_obs_steps_dim=True) - action = self.forward(batch, return_loss=False) - return action[: self.cfg.n_action_steps] - def __call__(self, *args, **kwargs) -> dict: - # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method. - return self.update(*args, **kwargs) - def _preprocess_batch( self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False ) -> dict[str, Tensor]: @@ -216,9 +213,7 @@ class ActionChunkingTransformerPolicy(nn.Module): # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get # the image index dimension. - def update(self, batch, **_) -> dict: - """Run the model in train mode, compute the loss, and do an optimization step.""" - start_time = time.time() + def compute_loss(self, batch, **_) -> float: self._preprocess_batch(batch) self.train() @@ -230,6 +225,12 @@ class ActionChunkingTransformerPolicy(nn.Module): assert batch_size % num_slices == 0 loss = self.forward(batch, return_loss=True)["loss"] + return loss + + def update(self, batch, **_) -> dict: + """Run the model in train mode, compute the loss, and do an optimization step.""" + start_time = time.time() + loss = self.compute_loss(batch) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5ff6538d..8e4c1961 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy(batch, step=step) + train_info = policy.update(batch, step=step) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: @@ -313,7 +313,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy(batch, step) + train_info = policy.update(batch, step) if step % cfg.log_freq == 0: log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) diff --git a/tests/test_policies.py b/tests/test_policies.py index 8ccc7c62..2547a3a2 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -4,11 +4,13 @@ import torch from lerobot.common.datasets.utils import cycle from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.protocol import Policy from lerobot.common.envs.factory import make_env from lerobot.common.datasets.factory import make_dataset from lerobot.common.utils import init_hydra_config from .utils import DEVICE, DEFAULT_CONFIG_PATH + @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", [ @@ -27,6 +29,7 @@ def test_policy(env_name, policy_name, extra_overrides): """ Tests: - Making the policy object. + - Checking that the policy follows the correct protocol. - Updating the policy. - Using the policy to select actions at inference time. - Test the action can be applied to the policy @@ -38,10 +41,14 @@ def test_policy(env_name, policy_name, extra_overrides): f"policy={policy_name}", f"device={DEVICE}", ] - + extra_overrides + + extra_overrides, ) # Check that we can make the policy object. policy = make_policy(cfg) + # Check that the policy follows the required protocol. + assert isinstance( + policy, Policy + ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}." # Check that we run select_actions and get the appropriate output. dataset = make_dataset(cfg) env = make_env(cfg, num_parallel_envs=2) @@ -62,7 +69,7 @@ def test_policy(env_name, policy_name, extra_overrides): batch[key] = batch[key].to(DEVICE, non_blocking=True) # Test updating the policy - policy(batch, step=0) + policy.update(batch, step=0) # reset the policy and environment policy.reset() @@ -83,4 +90,3 @@ def test_policy(env_name, policy_name, extra_overrides): # Test step through policy env.step(action) - From 8a322da42292cd4b1c19a33821a8ee4c57b5a3d6 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 16 Apr 2024 16:35:04 +0100 Subject: [PATCH 2/7] backup wip --- tests/test_policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_policies.py b/tests/test_policies.py index 2547a3a2..f53e402a 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -4,7 +4,7 @@ import torch from lerobot.common.datasets.utils import cycle from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.policies.factory import make_policy -from lerobot.common.policies.protocol import Policy +from lerobot.common.policies.policy_protocol import Policy from lerobot.common.envs.factory import make_env from lerobot.common.datasets.factory import make_dataset from lerobot.common.utils import init_hydra_config From cb3978b5f3ac78fd03bc3398257f7ce18ad4345e Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 16 Apr 2024 18:12:39 +0100 Subject: [PATCH 3/7] backup wip --- lerobot/common/policies/act/modeling_act.py | 107 +++++++++++--------- lerobot/common/policies/policy_protocol.py | 45 ++++++++ 2 files changed, 102 insertions(+), 50 deletions(-) create mode 100644 lerobot/common/policies/policy_protocol.py diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 143d1459..e1576777 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -165,8 +165,8 @@ class ActionChunkingTransformerPolicy(nn.Module): if self.cfg.n_action_steps is not None: self._action_queue = deque([], maxlen=self.cfg.n_action_steps) - # def forward(self, batch: dict[str, Tensor], **_) -> Tensor: - # return self.select_action(self, batch) + def forward(self, batch: dict[str, Tensor], **_) -> Tensor: + return self.select_action(self, batch) @torch.no_grad def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: @@ -186,20 +186,28 @@ class ActionChunkingTransformerPolicy(nn.Module): def _select_actions(self, batch: dict[str, Tensor]) -> Tensor: """Use the action chunking transformer to generate a sequence of actions.""" self.eval() - self._preprocess_batch(batch, add_obs_steps_dim=True) - action = self.forward(batch, return_loss=False) - return action[: self.cfg.n_action_steps] + batch = self._reshape_batch(batch, add_obs_steps_dim=True) + actions, _ = self._forward( + batch["observation.state"], self.image_normalizer(batch["observation.images.top"]) + ) + return actions[: self.cfg.n_action_steps] + + def _reshape_batch(self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False) -> dict[str, Tensor]: + """Reshapes the batch items to account for various requirements of this policy. - def _preprocess_batch( - self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False - ) -> dict[str, Tensor]: - """ This function expects `batch` to have (at least): { "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration). "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images. } + + TODO(alexander-soare): Right now this method does and undoes reshaping operations. This is just to + separate out the core logic from the temporary logic. See comments below. """ + # Create a shallow copy. + batch = dict(batch) + + # Add a dimension for observation steps. if add_obs_steps_dim: # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now, # this just amounts to an unsqueeze. @@ -207,18 +215,29 @@ class ActionChunkingTransformerPolicy(nn.Module): if k.startswith("observation."): batch[k] = batch[k].unsqueeze(1) - if batch["observation.state"].shape[1] != 1: + # Temporary logic to remove the observation step dimension as the policy does not yet handle it. + # TODO(alexander-soare): generalize this to multiple observations steps. + # Check that there is only 1 observation step (policy does not yet handle more). + if not all(batch[k].shape[1] == 1 for k in batch if k.startswith("observation.")): raise ValueError(self._multiple_obs_steps_not_handled_msg) - batch["observation.state"] = batch["observation.state"].squeeze(1) - # TODO(alexander-soare): generalize this to multiple images. + # Remove observation steps dimension. + for k in batch: + if k.startswith("observation."): + batch[k] = batch[k].squeeze(1) + + # Temporary logic to add the multiple image dimension back in. + # TODO(alexander-soare): generalize this to multiple images. Once resolved, this logic will stack all + # images. assert ( sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1 - ), "ACT only handles one image for now." - # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get - # the image index dimension. + ), f"{self.__class__.__name__} only handles one image for now." + # Since we only handle one image, just unsqueeze instead of stacking. + batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1) + + return batch def compute_loss(self, batch, **_) -> float: - self._preprocess_batch(batch) + batch = self._reshape_batch(batch) self.train() @@ -228,7 +247,28 @@ class ActionChunkingTransformerPolicy(nn.Module): assert batch_size % self.cfg.chunk_size == 0 assert batch_size % num_slices == 0 - loss = self.forward(batch, return_loss=True)["loss"] + actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward( + batch["observation.state"], + self.image_normalizer(batch["observation.images.top"]), + batch["action"], + ) + + l1_loss = ( + F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + ).mean() + + if self.cfg.use_vae: + # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for + # each dimension independently, we sum over the latent dimension to get the total + # KL-divergence per batch element, then take the mean over the batch. + # (See App. B of https://arxiv.org/abs/1312.6114 for more details). + mean_kld = ( + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ) + loss = l1_loss + mean_kld * self.cfg.kl_weight + else: + loss = l1_loss + return loss def update(self, batch, **_) -> dict: @@ -255,39 +295,6 @@ class ActionChunkingTransformerPolicy(nn.Module): return info - def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor: - """A forward pass through the DNN part of this policy with optional loss computation.""" - images = self.image_normalizer(batch["observation.images.top"]) - - if return_loss: # training time - actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward( - batch["observation.state"], images, batch["action"] - ) - - l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") - * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() - - loss_dict = {} - loss_dict["l1"] = l1_loss - if self.cfg.use_vae: - # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for - # each dimension independently, we sum over the latent dimension to get the total - # KL-divergence per batch element, then take the mean over the batch. - # (See App. B of https://arxiv.org/abs/1312.6114 for more details). - mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() - ) - loss_dict["kl"] = mean_kld - loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight - else: - loss_dict["loss"] = loss_dict["l1"] - return loss_dict - else: - action, _ = self._forward(batch["observation.state"], images) - return action - def _forward( self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]: diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py new file mode 100644 index 00000000..3a396e84 --- /dev/null +++ b/lerobot/common/policies/policy_protocol.py @@ -0,0 +1,45 @@ +"""A protocol that all policies should follow. + +This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes +subclass a base class. + +The protocol structure, method signatures, and docstrings should be used by developers as a reference for +how to implement new policies. +""" + +from typing import Protocol, runtime_checkable + +from torch import Tensor + + +@runtime_checkable +class Policy(Protocol): + """The required interface for implementing a policy.""" + + name: str + + def reset(self): + """To be called whenever the environment is reset. + + Does things like clearing caches. + """ + + def forward(self, batch: dict[str, Tensor], **kwargs): + """Wired to `select_action`.""" + + def select_action(self, batch: dict[str, Tensor], **kwargs): + """Return one action to run in the environment (potentially in batch mode). + + When the model uses a history of observations, or outputs a sequence of actions, this method deals + with caching. + """ + + def compute_loss(self, batch: dict[str, Tensor], **kwargs): + """Runs the batch through the model and computes the loss for training or validation.""" + + def update(self, batch, **kwargs): + """Does compute_loss then an optimization step. + + TODO(alexander-soare): We will move the optimization step back into the training loop, so this will + disappear. + """ From c50a13ab31e538f3ffea704e456f0c57f8d4be2f Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 17 Apr 2024 10:50:54 +0100 Subject: [PATCH 4/7] draft --- .../common/policies/act/configuration_act.py | 15 +- lerobot/common/policies/act/modeling_act.py | 165 +++++++----------- lerobot/configs/policy/act.yaml | 2 - 3 files changed, 79 insertions(+), 103 deletions(-) diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 1b438f2d..211a8ed0 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -103,12 +103,21 @@ class ActionChunkingTransformerConfig: def __post_init__(self): """Input validation (not exhaustive).""" if not self.vision_backbone.startswith("resnet"): - raise ValueError("`vision_backbone` must be one of the ResNet variants.") + raise ValueError( + f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." + ) if self.use_temporal_aggregation: raise NotImplementedError("Temporal aggregation is not yet implemented.") if self.n_action_steps > self.chunk_size: raise ValueError( - "The chunk size is the upper bound for the number of action steps per model invocation." + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" ) if self.camera_names != ["top"]: - raise ValueError("For now, `camera_names` can only be ['top']") + raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.") + if len(set(self.camera_names)) != len(self.camera_names): + raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.") diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index e1576777..af8566c7 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -20,7 +20,9 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig +from lerobot.common.policies.act.configuration_act import ( + ActionChunkingTransformerConfig, +) class ActionChunkingTransformerPolicy(nn.Module): @@ -61,9 +63,6 @@ class ActionChunkingTransformerPolicy(nn.Module): """ name = "act" - _multiple_obs_steps_not_handled_msg = ( - "ActionChunkingTransformerPolicy does not handle multiple observation steps." - ) def __init__(self, cfg: ActionChunkingTransformerConfig | None = None): """ @@ -74,8 +73,6 @@ class ActionChunkingTransformerPolicy(nn.Module): super().__init__() if cfg is None: cfg = ActionChunkingTransformerConfig() - if cfg.n_obs_steps != 1: - raise ValueError(self._multiple_obs_steps_not_handled_msg) self.cfg = cfg # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. @@ -102,7 +99,11 @@ class ActionChunkingTransformerPolicy(nn.Module): mean=cfg.image_normalization_mean, std=cfg.image_normalization_std ) backbone_model = getattr(torchvision.models, cfg.vision_backbone)( - replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], + replace_stride_with_dilation=[ + False, + False, + cfg.replace_final_stride_with_dilation, + ], pretrained=cfg.use_pretrained_backbone, norm_layer=FrozenBatchNorm2d, ) @@ -176,82 +177,16 @@ class ActionChunkingTransformerPolicy(nn.Module): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ + self.eval() if len(self._action_queue) == 0: - # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has - # shape (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(self._select_actions(batch).transpose(0, 1)) + # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively + # has shape (n_action_steps, batch_size, *), hence the transpose. + self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1)) return self._action_queue.popleft() - @torch.no_grad - def _select_actions(self, batch: dict[str, Tensor]) -> Tensor: - """Use the action chunking transformer to generate a sequence of actions.""" - self.eval() - batch = self._reshape_batch(batch, add_obs_steps_dim=True) - actions, _ = self._forward( - batch["observation.state"], self.image_normalizer(batch["observation.images.top"]) - ) - return actions[: self.cfg.n_action_steps] - - def _reshape_batch(self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False) -> dict[str, Tensor]: - """Reshapes the batch items to account for various requirements of this policy. - - This function expects `batch` to have (at least): - { - "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration). - "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images. - } - - TODO(alexander-soare): Right now this method does and undoes reshaping operations. This is just to - separate out the core logic from the temporary logic. See comments below. - """ - # Create a shallow copy. - batch = dict(batch) - - # Add a dimension for observation steps. - if add_obs_steps_dim: - # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now, - # this just amounts to an unsqueeze. - for k in batch: - if k.startswith("observation."): - batch[k] = batch[k].unsqueeze(1) - - # Temporary logic to remove the observation step dimension as the policy does not yet handle it. - # TODO(alexander-soare): generalize this to multiple observations steps. - # Check that there is only 1 observation step (policy does not yet handle more). - if not all(batch[k].shape[1] == 1 for k in batch if k.startswith("observation.")): - raise ValueError(self._multiple_obs_steps_not_handled_msg) - # Remove observation steps dimension. - for k in batch: - if k.startswith("observation."): - batch[k] = batch[k].squeeze(1) - - # Temporary logic to add the multiple image dimension back in. - # TODO(alexander-soare): generalize this to multiple images. Once resolved, this logic will stack all - # images. - assert ( - sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1 - ), f"{self.__class__.__name__} only handles one image for now." - # Since we only handle one image, just unsqueeze instead of stacking. - batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1) - - return batch - def compute_loss(self, batch, **_) -> float: - batch = self._reshape_batch(batch) - - self.train() - - num_slices = self.cfg.batch_size - batch_size = self.cfg.chunk_size * num_slices - - assert batch_size % self.cfg.chunk_size == 0 - assert batch_size % num_slices == 0 - - actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward( - batch["observation.state"], - self.image_normalizer(batch["observation.images.top"]), - batch["action"], - ) + """Runs the batch through the model and computes the loss for training or validation.""" + actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch) l1_loss = ( F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) @@ -274,6 +209,7 @@ class ActionChunkingTransformerPolicy(nn.Module): def update(self, batch, **_) -> dict: """Run the model in train mode, compute the loss, and do an optimization step.""" start_time = time.time() + self.train() loss = self.compute_loss(batch) loss.backward() @@ -295,35 +231,64 @@ class ActionChunkingTransformerPolicy(nn.Module): return info - def _forward( - self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None - ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]: + def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Stacks all the images in a batch and puts them in a new key: "observation.images". + + This function expects `batch` to have (at least): + { + "observation.state": (B, state_dim) batch of robot states. + "observation.images.{name}": (B, C, H, W) tensor of images. + } """ - Args: - robot_state: (B, J) batch of robot joint configurations. - image: (B, N, C, H, W) batch of N camera frames. - actions: (B, S, A) batch of actions from the target dataset which must be provided if the - VAE is enabled and the model is in training mode. + # Check that there is only one image. + # TODO(alexander-soare): generalize this to multiple images. + provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")} + if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0: + raise ValueError( + f"The following camera images are missing from the provided batch: {missing}. Check the " + "configuration parameter: `camera_names`." + ) + # Stack images in the order dictated by the camera names. + batch["observation.images"] = torch.stack( + [batch[f"observation.images.{name}"] for name in self.cfg.camera_names], + dim=-4, + ) + + def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: + """A forward pass through the Action Chunking Transformer (with optional VAE encoder). + + `batch` should have the following structure: + + { + "observation.state": (B, state_dim) batch of robot states. + "observation.images": (B, n_cameras, C, H, W) batch of images. + "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions. + } + Returns: - (B, S, A) batch of action sequences + (B, chunk_size, action_dim) batch of action sequences Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the latent dimension. """ if self.cfg.use_vae and self.training: assert ( - actions is not None + "action" in batch ), "actions must be provided when using the variational objective in training mode." - batch_size = robot_state.shape[0] + self._stack_images(batch) + + batch_size = batch["observation.state"].shape[0] # Prepare the latent for input to the transformer encoder. - if self.cfg.use_vae and actions is not None: + if self.cfg.use_vae and "action" in batch: # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size ) # (B, 1, D) - robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze( + 1 + ) # (B, 1, D) + action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) # Prepare fixed positional embedding. @@ -345,15 +310,16 @@ class ActionChunkingTransformerPolicy(nn.Module): # When not using the VAE encoder, we set the latent to be all zeros. mu = log_sigma_x2 = None latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to( - robot_state.device + batch["observation.state"].device ) # Prepare all other transformer encoder inputs. # Camera observation features and positional embeddings. all_cam_features = [] all_cam_pos_embeds = [] - for cam_id, _ in enumerate(self.cfg.camera_names): - cam_features = self.backbone(image[:, cam_id])["feature_map"] + images = self.image_normalizer(batch["observation.images"]) + for cam_index in range(len(self.cfg.camera_names)): + cam_features = self.backbone(images[:, cam_index])["feature_map"] cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) all_cam_features.append(cam_features) @@ -363,7 +329,7 @@ class ActionChunkingTransformerPolicy(nn.Module): cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3) # Get positional embeddings for robot state and latent. - robot_state_embed = self.encoder_robot_state_input_proj(robot_state) + robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) latent_embed = self.encoder_latent_input_proj(latent_sample) # Stack encoder input and positional embeddings moving to (S, B, C). @@ -479,7 +445,10 @@ class _TransformerDecoder(nn.Module): ) -> Tensor: for layer in self.layers: x = layer( - x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed + x, + encoder_out, + decoder_pos_embed=decoder_pos_embed, + encoder_pos_embed=encoder_pos_embed, ) if self.norm is not None: x = self.norm(x) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 5dd70d71..eb4e512b 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -67,6 +67,4 @@ policy: utd: 1 delta_timestamps: - observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" - observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" action: "[i / ${fps} for i in range(${policy.chunk_size})]" From 63e5ec64837ae765289d08b3766300105a31198a Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 17 Apr 2024 11:01:01 +0100 Subject: [PATCH 5/7] revert some formatting changes --- lerobot/common/policies/act/modeling_act.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index af8566c7..6b9b4e0a 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -20,9 +20,7 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.common.policies.act.configuration_act import ( - ActionChunkingTransformerConfig, -) +from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig class ActionChunkingTransformerPolicy(nn.Module): @@ -99,11 +97,7 @@ class ActionChunkingTransformerPolicy(nn.Module): mean=cfg.image_normalization_mean, std=cfg.image_normalization_std ) backbone_model = getattr(torchvision.models, cfg.vision_backbone)( - replace_stride_with_dilation=[ - False, - False, - cfg.replace_final_stride_with_dilation, - ], + replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], pretrained=cfg.use_pretrained_backbone, norm_layer=FrozenBatchNorm2d, ) @@ -445,10 +439,7 @@ class _TransformerDecoder(nn.Module): ) -> Tensor: for layer in self.layers: x = layer( - x, - encoder_out, - decoder_pos_embed=decoder_pos_embed, - encoder_pos_embed=encoder_pos_embed, + x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed ) if self.norm is not None: x = self.norm(x) From 2298ddf2266df1fed33cb14d991b0cae9938d008 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 17 Apr 2024 16:21:37 +0100 Subject: [PATCH 6/7] wip --- lerobot/common/policies/act/modeling_act.py | 22 +++++++++---------- .../policies/diffusion/modeling_diffusion.py | 12 +++++++--- lerobot/common/policies/policy_protocol.py | 14 ++++++------ 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 6b9b4e0a..c1af4ef4 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -160,9 +160,6 @@ class ActionChunkingTransformerPolicy(nn.Module): if self.cfg.n_action_steps is not None: self._action_queue = deque([], maxlen=self.cfg.n_action_steps) - def forward(self, batch: dict[str, Tensor], **_) -> Tensor: - return self.select_action(self, batch) - @torch.no_grad def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: """Select a single action given environment observations. @@ -178,14 +175,15 @@ class ActionChunkingTransformerPolicy(nn.Module): self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1)) return self._action_queue.popleft() - def compute_loss(self, batch, **_) -> float: - """Runs the batch through the model and computes the loss for training or validation.""" + def forward(self, batch, **_) -> dict[str, Tensor]: + """Run the batch through the model and compute the loss for training or validation.""" actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch) l1_loss = ( F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) ).mean() + loss_dict = {"l1_loss": l1_loss} if self.cfg.use_vae: # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total @@ -194,23 +192,23 @@ class ActionChunkingTransformerPolicy(nn.Module): mean_kld = ( (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() ) - loss = l1_loss + mean_kld * self.cfg.kl_weight + loss_dict["kld_loss"] = mean_kld + loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight else: - loss = l1_loss + loss_dict["loss"] = l1_loss - return loss + return loss_dict def update(self, batch, **_) -> dict: """Run the model in train mode, compute the loss, and do an optimization step.""" start_time = time.time() self.train() - loss = self.compute_loss(batch) + loss_dict = self.forward(batch) + loss = loss_dict["loss"] loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( - self.parameters(), - self.cfg.grad_clip_norm, - error_if_nonfinite=False, + self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False ) self.optimizer.step() diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index dfab9bb7..e7cc62f4 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module): name = "diffusion" def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0): - super().__init__() """ Args: cfg: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. """ + super().__init__() # TODO(alexander-soare): LR scheduler will be removed. assert lr_scheduler_num_training_steps > 0 if cfg is None: @@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module): action = self._queues["action"].popleft() return action - def forward(self, batch, **_): + def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]: + """Run the batch through the model and compute the loss for training or validation.""" + loss = self.diffusion.compute_loss(batch) + return {"loss": loss} + + def update(self, batch: dict[str, Tensor], **_) -> dict: + """Run the model in train mode, compute the loss, and do an optimization step.""" start_time = time.time() self.diffusion.train() - loss = self.diffusion.compute_loss(batch) + loss = self.forward(batch)["loss"] loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py index 3a396e84..6401c734 100644 --- a/lerobot/common/policies/policy_protocol.py +++ b/lerobot/common/policies/policy_protocol.py @@ -24,20 +24,20 @@ class Policy(Protocol): Does things like clearing caches. """ - def forward(self, batch: dict[str, Tensor], **kwargs): - """Wired to `select_action`.""" + def forward(self, batch: dict[str, Tensor]) -> dict: + """Run the batch through the model and compute the loss for training or validation. - def select_action(self, batch: dict[str, Tensor], **kwargs): + Returns a dictionary with "loss" and maybe other information. + """ + + def select_action(self, batch: dict[str, Tensor]): """Return one action to run in the environment (potentially in batch mode). When the model uses a history of observations, or outputs a sequence of actions, this method deals with caching. """ - def compute_loss(self, batch: dict[str, Tensor], **kwargs): - """Runs the batch through the model and computes the loss for training or validation.""" - - def update(self, batch, **kwargs): + def update(self, batch): """Does compute_loss then an optimization step. TODO(alexander-soare): We will move the optimization step back into the training loop, so this will From dd9c6eed15efb28d82c5742c59d09560417cb65d Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 17 Apr 2024 16:27:57 +0100 Subject: [PATCH 7/7] Add temporary patch in TD-MPC --- lerobot/common/policies/tdmpc/policy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 14728576..ed28c4a6 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -330,6 +330,10 @@ class TDMPCPolicy(nn.Module): return td_target def forward(self, batch, step): + # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation. + raise NotImplementedError() + + def update(self, batch, step): """Main update function. Corresponds to one iteration of the model learning.""" start_time = time.time()