diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index f2b16a1e..ae90b842 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -38,6 +38,38 @@
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 
+def focal_regression_loss(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    gamma: float = 2.0,
+    alpha: float = 0.25,
+    reduction: str = "mean",
+) -> torch.Tensor:
+    """
+    Computes a focal version of the L1 loss for regression tasks.
+
+    Args:
+        input (Tensor): Predicted values.
+        target (Tensor): Ground-truth values.
+        gamma (float): Focusing parameter. Controls how strongly the loss focuses on difficult examples: gamma = 0 recovers a plain alpha-scaled L1 loss, while larger values down-weight easy (low-error) examples more aggressively.
+        alpha (float): Weighting factor. Balances the focal term to prevent excessively large gradients; a lower alpha tempers the aggressive scaling and helps keep training stable.
+        reduction (str): 'mean', 'sum', or 'none'.
+
+    Returns:
+        Tensor: The computed loss.
+    """
+    # Standard per-element L1 error.
+    l1_loss = torch.abs(input - target)
+    focal_weight = (1 - torch.exp(-l1_loss)) ** gamma  # near 0 for small errors, approaches 1 for large ones
+    loss = alpha * focal_weight * l1_loss
+    if reduction == "mean":
+        return loss.mean()
+    elif reduction == "sum":
+        return loss.sum()
+    else:
+        return loss
+
+
 class ACTPolicy(PreTrainedPolicy):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
@@ -155,11 +187,13 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        focal_loss = focal_regression_loss(
+            batch["action"], actions_hat, gamma=2.0, alpha=0.25, reduction="none"
+        )
+        focal_loss = focal_loss * ~batch["action_is_pad"].unsqueeze(-1)  # zero out padded action steps
+        focal_loss = focal_loss.mean()
 
-        loss_dict = {"l1_loss": l1_loss.item()}
+        loss_dict = {"focal_loss": focal_loss.item()}
         if self.config.use_vae:
             # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
@@ -169,9 +203,9 @@ class ACTPolicy(PreTrainedPolicy):
                 (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
             loss_dict["kld_loss"] = mean_kld.item()
-            loss = l1_loss + mean_kld * self.config.kl_weight
+            loss = focal_loss + mean_kld * self.config.kl_weight
         else:
-            loss = l1_loss
+            loss = focal_loss
 
         return loss, loss_dict
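
Review note: a minimal standalone sketch of what the new weighting does, with the loss body copied verbatim from this patch. The residual values below are arbitrary, chosen only to illustrate the trend; they are not taken from any lerobot dataset.

```python
import torch


def focal_regression_loss(input, target, gamma=2.0, alpha=0.25, reduction="mean"):
    # Same computation as the patched function above.
    l1_loss = torch.abs(input - target)
    focal_weight = (1 - torch.exp(-l1_loss)) ** gamma
    loss = alpha * focal_weight * l1_loss
    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    return loss


# Arbitrary residuals, from "easy" (small error) to "hard" (large error).
errors = torch.tensor([0.01, 0.1, 0.5, 1.0, 2.0])
targets = torch.zeros_like(errors)

focal = focal_regression_loss(errors, targets, reduction="none")
plain_l1 = 0.25 * errors  # alpha-scaled L1, i.e. the gamma = 0 case

# Ratio of focal loss to plain alpha-scaled L1: the effective focal weight
# (1 - exp(-|error|)) ** gamma applied to each residual.
print(focal / plain_l1)
# roughly [1e-4, 9e-3, 0.15, 0.40, 0.75]
```

The ratio is exactly the focal weight, so residuals near zero contribute almost nothing to the loss (and its gradient), while residuals above ~1 retain most of their plain-L1 weight; this is the property the patch relies on to emphasize hard timesteps over already well-fit ones.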