Add focal loss
parent f994febca4
commit 0108caacdc
@@ -38,6 +38,38 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 
 
+def focal_regression_loss(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    gamma: float = 2.0,
+    alpha: float = 0.25,
+    reduction: str = "mean",
+) -> torch.Tensor:
+    """
+    Computes a focal version of the L1 loss for regression tasks.
+
+    Args:
+        input (Tensor): Predicted values.
+        target (Tensor): Ground-truth values.
+        gamma (float): Focusing parameter. With gamma = 0 the loss reduces to a plain (alpha-scaled) L1 loss; larger values focus the loss more strongly on difficult examples.
+        alpha (float): Weighting factor that scales the focal term to prevent excessively large gradients; a lower alpha damps the scaling and helps keep training stable.
+        reduction (str): 'mean', 'sum', or 'none'.
+
+    Returns:
+        Tensor: The computed loss.
+    """
+    # Standard L1 error
+    l1_loss = torch.abs(input - target)
+    focal_weight = (1 - torch.exp(-l1_loss)) ** gamma
+    loss = alpha * focal_weight * l1_loss
+    if reduction == "mean":
+        return loss.mean()
+    elif reduction == "sum":
+        return loss.sum()
+    else:
+        return loss
+
+
 class ACTPolicy(PreTrainedPolicy):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
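A quick sanity check of the focal weighting defined above: with the defaults gamma=2.0 and alpha=0.25, small residuals are suppressed far more strongly than large ones. This is a minimal sketch rather than code from the repository; it assumes focal_regression_loss from the hunk above is in scope (for example, copied into a local script), and the tensor values are arbitrary.

import torch

pred = torch.tensor([0.05, 1.0, 5.0])
target = torch.zeros(3)

plain_l1 = torch.abs(pred - target)                              # [0.05, 1.00, 5.00]
focal = focal_regression_loss(pred, target, reduction="none")

# Effective per-element scaling relative to plain L1: alpha * (1 - exp(-|err|)) ** gamma
print(focal / plain_l1)  # approx. [0.0006, 0.0999, 0.2466]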
@@ -155,11 +187,13 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        focal_loss = focal_regression_loss(
+            batch["action"], actions_hat, gamma=2.0, alpha=0.25, reduction="none"
+        )
+        focal_loss = focal_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        focal_loss = focal_loss.mean()
 
-        loss_dict = {"l1_loss": l1_loss.item()}
+        loss_dict = {"focal_loss": focal_loss.item()}
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
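For clarity, the masking step introduced above can be reproduced in isolation: padded timesteps are zeroed out before the mean is taken, matching what the replaced L1 path did. This is a standalone sketch with made-up shapes, not code from the repository.

import torch

focal_loss = torch.rand(2, 3, 4)                      # per-element loss from reduction="none"
action_is_pad = torch.tensor([[False, False, True],
                              [False, True, True]])   # stand-in for batch["action_is_pad"]

masked = focal_loss * ~action_is_pad.unsqueeze(-1)    # zero the padded timesteps
loss = masked.mean()                                  # mean over all elements, zeros included

As in the previous L1 formulation, the final .mean() averages over every element, so padded positions contribute zeros to the numerator while still counting toward the denominator.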
@@ -169,9 +203,9 @@ class ACTPolicy(PreTrainedPolicy):
                 (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
             loss_dict["kld_loss"] = mean_kld.item()
-            loss = l1_loss + mean_kld * self.config.kl_weight
+            loss = focal_loss + mean_kld * self.config.kl_weight
         else:
-            loss = l1_loss
+            loss = focal_loss
 
         return loss, loss_dict
 
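Finally, a tiny self-contained sketch of how the total training loss is assembled after this change, mirroring the branch in the last hunk; the numeric values and kl_weight are placeholders, not values read from the real config.

import torch

focal_loss = torch.tensor(0.05)   # placeholder reconstruction loss
mean_kld = torch.tensor(0.10)     # placeholder KL term
kl_weight = 10.0                  # stands in for self.config.kl_weight
use_vae = True

loss = focal_loss + mean_kld * kl_weight if use_vae else focal_loss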