diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 6b9b4e0a..c1af4ef4 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -160,9 +160,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.cfg.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
 
-    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
-        return self.select_action(self, batch)
-
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
         """Select a single action given environment observations.
@@ -178,14 +175,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
             self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
         return self._action_queue.popleft()
 
-    def compute_loss(self, batch, **_) -> float:
-        """Runs the batch through the model and computes the loss for training or validation."""
+    def forward(self, batch, **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
         actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
 
         l1_loss = (
             F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()
 
+        loss_dict = {"l1_loss": l1_loss}
         if self.cfg.use_vae:
             # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
@@ -194,23 +192,23 @@ class ActionChunkingTransformerPolicy(nn.Module):
             mean_kld = (
                 (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
-            loss = l1_loss + mean_kld * self.cfg.kl_weight
+            loss_dict["kld_loss"] = mean_kld
+            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
         else:
-            loss = l1_loss
+            loss_dict["loss"] = l1_loss
 
-        return loss
+        return loss_dict
 
     def update(self, batch, **_) -> dict:
         """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
         self.train()
 
-        loss = self.compute_loss(batch)
+        loss_dict = self.forward(batch)
+        loss = loss_dict["loss"]
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
+            self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
         )
         self.optimizer.step()
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index dfab9bb7..e7cc62f4 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module):
     name = "diffusion"
 
     def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
-        super().__init__()
         """
         Args:
             cfg: Policy configuration class instance or None, in which case the default instantiation of
                 the configuration class is used.
         """
+        super().__init__()
         # TODO(alexander-soare): LR scheduler will be removed.
         assert lr_scheduler_num_training_steps > 0
         if cfg is None:
@@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module):
         action = self._queues["action"].popleft()
         return action
 
-    def forward(self, batch, **_):
+    def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        loss = self.diffusion.compute_loss(batch)
+        return {"loss": loss}
+
+    def update(self, batch: dict[str, Tensor], **_) -> dict:
+        """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
 
         self.diffusion.train()
 
-        loss = self.diffusion.compute_loss(batch)
+        loss = self.forward(batch)["loss"]
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py
index 3a396e84..6401c734 100644
--- a/lerobot/common/policies/policy_protocol.py
+++ b/lerobot/common/policies/policy_protocol.py
@@ -24,20 +24,20 @@ class Policy(Protocol):
         Does things like clearing caches.
         """
 
-    def forward(self, batch: dict[str, Tensor], **kwargs):
-        """Wired to `select_action`."""
+    def forward(self, batch: dict[str, Tensor]) -> dict:
+        """Run the batch through the model and compute the loss for training or validation.
 
-    def select_action(self, batch: dict[str, Tensor], **kwargs):
+        Returns a dictionary with "loss" and maybe other information.
+        """
+
+    def select_action(self, batch: dict[str, Tensor]):
         """Return one action to run in the environment (potentially in batch mode).
 
         When the model uses a history of observations, or outputs a sequence of actions, this method deals
         with caching.
         """
 
-    def compute_loss(self, batch: dict[str, Tensor], **kwargs):
-        """Runs the batch through the model and computes the loss for training or validation."""
-
-    def update(self, batch, **kwargs):
+    def update(self, batch):
         """Does compute_loss then an optimization step.
 
         TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
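As a quick orientation to the refactored interface, here is a minimal sketch (not part of the patch) of how callers consume the revised Policy protocol: training code calls `update`, which internally runs `forward` and reads `"loss"` from the returned dict, while `select_action` remains the rollout entry point. The helper names and arguments below are illustrative assumptions, not code from the repository.

import torch
from torch import Tensor


def run_training_step(policy, batch: dict[str, Tensor]) -> dict:
    # Training: `update` runs `forward` (which returns {"loss": ..., ...}) and
    # performs the optimization step internally, returning logging info as a dict.
    return policy.update(batch)


def validation_loss(policy, batch: dict[str, Tensor]) -> float:
    # Validation: call `forward` directly and read the loss out of the returned dict.
    with torch.no_grad():
        return policy.forward(batch)["loss"].item()


def rollout_action(policy, observation: dict[str, Tensor]) -> Tensor:
    # Rollout: `select_action` returns one action per call and handles any internal
    # caching or queueing of observation history and action chunks (call `reset()`
    # on the policy at the start of each episode).
    return policy.select_action(observation)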