From cb3978b5f3ac78fd03bc3398257f7ce18ad4345e Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Tue, 16 Apr 2024 18:12:39 +0100
Subject: [PATCH] backup wip

---
 lerobot/common/policies/act/modeling_act.py | 107 +++++++++++---------
 lerobot/common/policies/policy_protocol.py  |  45 ++++++++
 2 files changed, 102 insertions(+), 50 deletions(-)
 create mode 100644 lerobot/common/policies/policy_protocol.py

diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 143d1459..e1576777 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -165,8 +165,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.cfg.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
 
-    # def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
-    #     return self.select_action(self, batch)
+    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
+        return self.select_action(self, batch)
 
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
@@ -186,20 +186,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
     def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """Use the action chunking transformer to generate a sequence of actions."""
         self.eval()
-        self._preprocess_batch(batch, add_obs_steps_dim=True)
-        action = self.forward(batch, return_loss=False)
-        return action[: self.cfg.n_action_steps]
+        batch = self._reshape_batch(batch, add_obs_steps_dim=True)
+        actions, _ = self._forward(
+            batch["observation.state"], self.image_normalizer(batch["observation.images.top"])
+        )
+        return actions[: self.cfg.n_action_steps]
+
+    def _reshape_batch(self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False) -> dict[str, Tensor]:
+        """Reshapes the batch items to account for various requirements of this policy.
 
-    def _preprocess_batch(
-        self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
-    ) -> dict[str, Tensor]:
-        """
         This function expects `batch` to have (at least):
         {
             "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
             "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
         }
+
+        TODO(alexander-soare): Right now this method does and undoes reshaping operations. This is just to
+        separate out the core logic from the temporary logic. See comments below.
         """
+        # Create a shallow copy.
+        batch = dict(batch)
+
+        # Add a dimension for observation steps.
         if add_obs_steps_dim:
             # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
             # this just amounts to an unsqueeze.
@@ -207,18 +215,29 @@ class ActionChunkingTransformerPolicy(nn.Module):
                 if k.startswith("observation."):
                     batch[k] = batch[k].unsqueeze(1)
 
-        if batch["observation.state"].shape[1] != 1:
+        # Temporary logic to remove the observation step dimension as the policy does not yet handle it.
+        # TODO(alexander-soare): generalize this to multiple observations steps.
+        # Check that there is only 1 observation step (policy does not yet handle more).
+        if not all(batch[k].shape[1] == 1 for k in batch if k.startswith("observation.")):
             raise ValueError(self._multiple_obs_steps_not_handled_msg)
-        batch["observation.state"] = batch["observation.state"].squeeze(1)
-        # TODO(alexander-soare): generalize this to multiple images.
+        # Remove observation steps dimension.
+        for k in batch:
+            if k.startswith("observation."):
+                batch[k] = batch[k].squeeze(1)
+
+        # Temporary logic to add the multiple image dimension back in.
+        # TODO(alexander-soare): generalize this to multiple images. Once resolved, this logic will stack all
+        # images.
         assert (
             sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
-        ), "ACT only handles one image for now."
-        # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
-        # the image index dimension.
+        ), f"{self.__class__.__name__} only handles one image for now."
+        # Since we only handle one image, just unsqueeze instead of stacking.
+        batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1)
+
+        return batch
 
     def compute_loss(self, batch, **_) -> float:
-        self._preprocess_batch(batch)
+        batch = self._reshape_batch(batch)
 
         self.train()
 
@@ -228,7 +247,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
         assert batch_size % self.cfg.chunk_size == 0
         assert batch_size % num_slices == 0
 
-        loss = self.forward(batch, return_loss=True)["loss"]
+        actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
+            batch["observation.state"],
+            self.image_normalizer(batch["observation.images.top"]),
+            batch["action"],
+        )
+
+        l1_loss = (
+            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+        ).mean()
+
+        if self.cfg.use_vae:
+            # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
+            # each dimension independently, we sum over the latent dimension to get the total
+            # KL-divergence per batch element, then take the mean over the batch.
+            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
+            mean_kld = (
+                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+            )
+            loss = l1_loss + mean_kld * self.cfg.kl_weight
+        else:
+            loss = l1_loss
+
         return loss
 
     def update(self, batch, **_) -> dict:
@@ -255,39 +295,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         return info
 
-    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
-        """A forward pass through the DNN part of this policy with optional loss computation."""
-        images = self.image_normalizer(batch["observation.images.top"])
-
-        if return_loss:  # training time
-            actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
-                batch["observation.state"], images, batch["action"]
-            )
-
-            l1_loss = (
-                F.l1_loss(batch["action"], actions_hat, reduction="none")
-                * ~batch["action_is_pad"].unsqueeze(-1)
-            ).mean()
-
-            loss_dict = {}
-            loss_dict["l1"] = l1_loss
-            if self.cfg.use_vae:
-                # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
-                # each dimension independently, we sum over the latent dimension to get the total
-                # KL-divergence per batch element, then take the mean over the batch.
-                # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-                mean_kld = (
-                    (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-                )
-                loss_dict["kl"] = mean_kld
-                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
-            else:
-                loss_dict["loss"] = loss_dict["l1"]
-            return loss_dict
-        else:
-            action, _ = self._forward(batch["observation.state"], images)
-            return action
-
     def _forward(
         self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
     ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py
new file mode 100644
index 00000000..3a396e84
--- /dev/null
+++ b/lerobot/common/policies/policy_protocol.py
@@ -0,0 +1,45 @@
+"""A protocol that all policies should follow.
+
+This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
+subclass a base class.
+
+The protocol structure, method signatures, and docstrings should be used by developers as a reference for
+how to implement new policies.
+"""
+
+from typing import Protocol, runtime_checkable
+
+from torch import Tensor
+
+
+@runtime_checkable
+class Policy(Protocol):
+    """The required interface for implementing a policy."""
+
+    name: str
+
+    def reset(self):
+        """To be called whenever the environment is reset.
+
+        Does things like clearing caches.
+        """
+
+    def forward(self, batch: dict[str, Tensor], **kwargs):
+        """Wired to `select_action`."""
+
+    def select_action(self, batch: dict[str, Tensor], **kwargs):
+        """Return one action to run in the environment (potentially in batch mode).
+
+        When the model uses a history of observations, or outputs a sequence of actions, this method deals
+        with caching.
+        """
+
+    def compute_loss(self, batch: dict[str, Tensor], **kwargs):
+        """Runs the batch through the model and computes the loss for training or validation."""
+
+    def update(self, batch, **kwargs):
+        """Does compute_loss then an optimization step.
+
+        TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
+        disappear.
+        """