diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index 7599fa63..f3e17539 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -330,8 +330,8 @@ class PI0Policy(PreTrainedPolicy): # For backward pass loss = losses.mean() - # For logging - loss_dict["l2_loss"] = loss.item() + # For logging. Use detach so won't create scalar to break graph when using torch.compile + loss_dict["l2_loss"] = loss.detach() return loss, loss_dict