From f15dd56d0e211d15fda3ba7dcec5a8f7dedb1ced Mon Sep 17 00:00:00 2001
From: IrvingF7 <13810062+IrvingF7@users.noreply.github.com>
Date: Fri, 14 Feb 2025 08:43:09 -0500
Subject: [PATCH] Replace loss.item() with loss.detach() so logging does not
 create a Python scalar and break the graph under torch.compile

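Tensor.item() copies the value to the CPU as a Python scalar, which
TorchDynamo cannot trace through, so it inserts a graph break.
Tensor.detach() keeps the value as a tensor and compiles cleanly. Below is
a minimal standalone sketch (toy model and a hypothetical train_step
function, not taken from the lerobot codebase) illustrating the difference:

    import torch

    # Minimal standalone sketch; toy model, not from lerobot.
    model = torch.nn.Linear(8, 1)

    @torch.compile(fullgraph=True)  # fullgraph=True turns any graph break into an error
    def train_step(x, y):
        pred = model(x)
        loss = ((pred - y) ** 2).mean()
        loss_dict = {}
        # loss.item() here would return a Python float and typically forces
        # a graph break, so fullgraph=True would raise; detach() keeps the
        # value as a tensor and the graph intact.
        loss_dict["l2_loss"] = loss.detach()
        return loss, loss_dict

    x, y = torch.randn(4, 8), torch.randn(4, 1)
    loss, loss_dict = train_step(x, y)
    loss.backward()
    print(float(loss_dict["l2_loss"]))  # convert to a scalar outside the compiled region

The logged value can still be turned into a number with float() or .item()
in the training loop, outside the compiled region.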
---
 lerobot/common/policies/pi0/modeling_pi0.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py
index c8b12caf..51f9dd9f 100644
--- a/lerobot/common/policies/pi0/modeling_pi0.py
+++ b/lerobot/common/policies/pi0/modeling_pi0.py
@@ -330,8 +330,8 @@ class PI0Policy(PreTrainedPolicy):
 
         # For backward pass
         loss = losses.mean()
-        # For logging
-        loss_dict["l2_loss"] = loss.item()
+        # For logging. detach() avoids creating a Python scalar, which would break the graph under torch.compile
+        loss_dict["l2_loss"] = loss.detach()
 
         return loss, loss_dict