From f15dd56d0e211d15fda3ba7dcec5a8f7dedb1ced Mon Sep 17 00:00:00 2001 From: IrvingF7 <13810062+IrvingF7@users.noreply.github.com> Date: Fri, 14 Feb 2025 08:43:09 -0500 Subject: [PATCH] modified loss.item to loss.detach, so that no scalar will be created and when using torch.compile, it will not break graph --- lerobot/common/policies/pi0/modeling_pi0.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index c8b12caf..51f9dd9f 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -330,8 +330,8 @@ class PI0Policy(PreTrainedPolicy): # For backward pass loss = losses.mean() - # For logging - loss_dict["l2_loss"] = loss.item() + # For logging. Use detach so won't create scalar to break graph when using torch.compile + loss_dict["l2_loss"] = loss.detach() return loss, loss_dict