diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 143d1459..e1576777 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -165,8 +165,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.cfg.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
 
-    # def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
-    #     return self.select_action(self, batch)
+    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
+        return self.select_action(batch)
 
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
@@ -186,20 +186,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
     def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """Use the action chunking transformer to generate a sequence of actions."""
         self.eval()
-        self._preprocess_batch(batch, add_obs_steps_dim=True)
-        action = self.forward(batch, return_loss=False)
-        return action[: self.cfg.n_action_steps]
+        batch = self._reshape_batch(batch, add_obs_steps_dim=True)
+        actions, _ = self._forward(
+            batch["observation.state"], self.image_normalizer(batch["observation.images.top"])
+        )
+        return actions[: self.cfg.n_action_steps]
+
+    def _reshape_batch(self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False) -> dict[str, Tensor]:
+        """Reshapes the batch items to account for various requirements of this policy.
 
-    def _preprocess_batch(
-        self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
-    ) -> dict[str, Tensor]:
-        """
         This function expects `batch` to have (at least):
         {
             "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
             "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
         }
+
+        TODO(alexander-soare): Right now this method does and undoes reshaping operations. This is just to
+        separate out the core logic from the temporary logic. See comments below.
         """
+        # Create a shallow copy.
+        batch = dict(batch)
+
+        # Add a dimension for observation steps.
         if add_obs_steps_dim:
             # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
             # this just amounts to an unsqueeze.
@@ -207,18 +215,29 @@ class ActionChunkingTransformerPolicy(nn.Module):
                 if k.startswith("observation."):
                     batch[k] = batch[k].unsqueeze(1)
 
-        if batch["observation.state"].shape[1] != 1:
+        # Temporary logic to remove the observation step dimension as the policy does not yet handle it.
+        # TODO(alexander-soare): generalize this to multiple observation steps.
+        # Check that there is only 1 observation step (policy does not yet handle more).
+        if not all(batch[k].shape[1] == 1 for k in batch if k.startswith("observation.")):
             raise ValueError(self._multiple_obs_steps_not_handled_msg)
-        batch["observation.state"] = batch["observation.state"].squeeze(1)
-        # TODO(alexander-soare): generalize this to multiple images.
+        # Remove observation steps dimension.
+        for k in batch:
+            if k.startswith("observation."):
+                batch[k] = batch[k].squeeze(1)
+
+        # Temporary logic to add the multiple image dimension back in.
+        # TODO(alexander-soare): generalize this to multiple images. Once resolved, this logic will stack all
+        # images.
         assert (
             sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
-        ), "ACT only handles one image for now."
-        # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
-        # the image index dimension.
+        ), f"{self.__class__.__name__} only handles one image for now."
+        # Since we only handle one image, just unsqueeze instead of stacking.
+        batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1)
+
+        return batch
 
     def compute_loss(self, batch, **_) -> float:
-        self._preprocess_batch(batch)
+        batch = self._reshape_batch(batch)
 
         self.train()
 
@@ -228,7 +247,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         assert batch_size % self.cfg.chunk_size == 0
         assert batch_size % num_slices == 0
-        loss = self.forward(batch, return_loss=True)["loss"]
+        actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
+            batch["observation.state"],
+            self.image_normalizer(batch["observation.images.top"]),
+            batch["action"],
+        )
+
+        l1_loss = (
+            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+        ).mean()
+
+        if self.cfg.use_vae:
+            # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
+            # each dimension independently, we sum over the latent dimension to get the total
+            # KL-divergence per batch element, then take the mean over the batch.
+            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
+            mean_kld = (
+                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+            )
+            loss = l1_loss + mean_kld * self.cfg.kl_weight
+        else:
+            loss = l1_loss
+
         return loss
 
     def update(self, batch, **_) -> dict:
@@ -255,39 +295,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         return info
 
-    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
-        """A forward pass through the DNN part of this policy with optional loss computation."""
-        images = self.image_normalizer(batch["observation.images.top"])
-
-        if return_loss:  # training time
-            actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
-                batch["observation.state"], images, batch["action"]
-            )
-
-            l1_loss = (
-                F.l1_loss(batch["action"], actions_hat, reduction="none")
-                * ~batch["action_is_pad"].unsqueeze(-1)
-            ).mean()
-
-            loss_dict = {}
-            loss_dict["l1"] = l1_loss
-            if self.cfg.use_vae:
-                # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
-                # each dimension independently, we sum over the latent dimension to get the total
-                # KL-divergence per batch element, then take the mean over the batch.
-                # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-                mean_kld = (
-                    (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-                )
-                loss_dict["kl"] = mean_kld
-                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
-            else:
-                loss_dict["loss"] = loss_dict["l1"]
-            return loss_dict
-        else:
-            action, _ = self._forward(batch["observation.state"], images)
-            return action
-
     def _forward(
         self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
     ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py
new file mode 100644
index 00000000..3a396e84
--- /dev/null
+++ b/lerobot/common/policies/policy_protocol.py
@@ -0,0 +1,45 @@
+"""A protocol that all policies should follow.
+
+This provides a mechanism for type-hinting and isinstance checks without requiring the policy classes to
+subclass a base class.
+
+The protocol structure, method signatures, and docstrings should be used by developers as a reference for
+how to implement new policies.
+"""
+
+from typing import Protocol, runtime_checkable
+
+from torch import Tensor
+
+
+@runtime_checkable
+class Policy(Protocol):
+    """The required interface for implementing a policy."""
+
+    name: str
+
+    def reset(self):
+        """To be called whenever the environment is reset.
+
+        Does things like clearing caches.
+        """
+
+    def forward(self, batch: dict[str, Tensor], **kwargs):
+        """Wired to `select_action`."""
+
+    def select_action(self, batch: dict[str, Tensor], **kwargs):
+        """Return one action to run in the environment (potentially in batch mode).
+
+        When the model uses a history of observations, or outputs a sequence of actions, this method deals
+        with caching.
+        """
+
+    def compute_loss(self, batch: dict[str, Tensor], **kwargs):
+        """Runs the batch through the model and computes the loss for training or validation."""
+
+    def update(self, batch, **kwargs):
+        """Does compute_loss then an optimization step.
+
+        TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
+        disappear.
+        """
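
For illustration, a minimal sketch (not part of the diff) of how the `@runtime_checkable` protocol above enables structural `isinstance` checks without subclassing. The `DummyPolicy` class below is hypothetical; note that `runtime_checkable` only verifies that the listed attributes and methods exist, not their signatures.

```python
from torch import Tensor

from lerobot.common.policies.policy_protocol import Policy


class DummyPolicy:  # hypothetical example; no inheritance from Policy
    name = "dummy"

    def reset(self):
        pass

    def forward(self, batch: dict[str, Tensor], **kwargs):
        return self.select_action(batch, **kwargs)

    def select_action(self, batch: dict[str, Tensor], **kwargs):
        raise NotImplementedError

    def compute_loss(self, batch: dict[str, Tensor], **kwargs):
        raise NotImplementedError

    def update(self, batch, **kwargs):
        raise NotImplementedError


# Passes: the check is purely structural (all protocol members are present).
assert isinstance(DummyPolicy(), Policy)
```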
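
And a standalone sketch of the loss that the new `compute_loss` computes, with made-up shapes and values for illustration: an L1 reconstruction term masked by the action padding plus, when the VAE is used, Dā‚–ā‚—(N(μ, σ²) ‖ N(0, 1)) summed over the latent dimension and averaged over the batch, scaled by the KL weight.

```python
import torch
import torch.nn.functional as F

# Made-up shapes: batch of 2, chunk of 3 actions, 4 action dims, 8 latent dims.
actions = torch.randn(2, 3, 4)          # ground-truth action chunk
actions_hat = torch.randn(2, 3, 4)      # predicted action chunk
action_is_pad = torch.tensor([[False, False, True], [False, False, False]])
mu_hat = torch.randn(2, 8)              # latent mean
log_sigma_x2_hat = torch.randn(2, 8)    # log(sigma^2) of the latent
kl_weight = 10.0                        # stands in for cfg.kl_weight

# L1 reconstruction loss with padded action steps masked out, as in the diff.
l1_loss = (F.l1_loss(actions, actions_hat, reduction="none") * ~action_is_pad.unsqueeze(-1)).mean()

# D_KL(N(mu, sigma^2) || N(0, 1)): per-dimension KL, summed over the latent dim, averaged over the batch.
mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1).mean()

loss = l1_loss + kl_weight * mean_kld
```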