From 2298ddf2266df1fed33cb14d991b0cae9938d008 Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Wed, 17 Apr 2024 16:21:37 +0100
Subject: [PATCH] wip

---
 lerobot/common/policies/act/modeling_act.py   | 22 +++++++++----------
 .../policies/diffusion/modeling_diffusion.py  | 12 +++++++---
 lerobot/common/policies/policy_protocol.py    | 14 ++++++------
 3 files changed, 26 insertions(+), 22 deletions(-)

diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 6b9b4e0a..c1af4ef4 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -160,9 +160,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.cfg.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
 
-    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
-        return self.select_action(self, batch)
-
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
         """Select a single action given environment observations.
@@ -178,14 +175,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
             self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
         return self._action_queue.popleft()
 
-    def compute_loss(self, batch, **_) -> float:
-        """Runs the batch through the model and computes the loss for training or validation."""
+    def forward(self, batch, **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
         actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
 
         l1_loss = (
             F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()
 
+        loss_dict = {"l1_loss": l1_loss}
         if self.cfg.use_vae:
             # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
@@ -194,23 +192,23 @@ class ActionChunkingTransformerPolicy(nn.Module):
             mean_kld = (
                 (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
-            loss = l1_loss + mean_kld * self.cfg.kl_weight
+            loss_dict["kld_loss"] = mean_kld
+            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
         else:
-            loss = l1_loss
+            loss_dict["loss"] = l1_loss
 
-        return loss
+        return loss_dict
 
     def update(self, batch, **_) -> dict:
         """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
         self.train()
-        loss = self.compute_loss(batch)
+        loss_dict = self.forward(batch)
+        loss = loss_dict["loss"]
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
+            self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
         )
 
         self.optimizer.step()
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index dfab9bb7..e7cc62f4 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module):
     name = "diffusion"
 
     def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
-        super().__init__()
         """
         Args:
             cfg: Policy configuration class instance or None, in which case the default instantiation of the
                  configuration class is used.
         """
+        super().__init__()
         # TODO(alexander-soare): LR scheduler will be removed.
         assert lr_scheduler_num_training_steps > 0
         if cfg is None:
@@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module):
         action = self._queues["action"].popleft()
         return action
 
-    def forward(self, batch, **_):
+    def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        loss = self.diffusion.compute_loss(batch)
+        return {"loss": loss}
+
+    def update(self, batch: dict[str, Tensor], **_) -> dict:
+        """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
 
         self.diffusion.train()
 
-        loss = self.diffusion.compute_loss(batch)
+        loss = self.forward(batch)["loss"]
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py
index 3a396e84..6401c734 100644
--- a/lerobot/common/policies/policy_protocol.py
+++ b/lerobot/common/policies/policy_protocol.py
@@ -24,20 +24,20 @@ class Policy(Protocol):
         Does things like clearing caches.
         """
 
-    def forward(self, batch: dict[str, Tensor], **kwargs):
-        """Wired to `select_action`."""
+    def forward(self, batch: dict[str, Tensor]) -> dict:
+        """Run the batch through the model and compute the loss for training or validation.
 
-    def select_action(self, batch: dict[str, Tensor], **kwargs):
+        Returns a dictionary with "loss" and maybe other information.
+        """
+
+    def select_action(self, batch: dict[str, Tensor]):
         """Return one action to run in the environment (potentially in batch mode).
 
         When the model uses a history of observations, or outputs a sequence of actions, this method deals
         with caching.
         """
 
-    def compute_loss(self, batch: dict[str, Tensor], **kwargs):
-        """Runs the batch through the model and computes the loss for training or validation."""
-
-    def update(self, batch, **kwargs):
+    def update(self, batch):
         """Does compute_loss then an optimization step.
 
         TODO(alexander-soare): We will move the optimization step back into the training loop, so this will