Changed loss.item() to loss.detach() so that no Python scalar is created; this avoids a graph break when using torch.compile

This commit is contained in:
IrvingF7 2025-02-14 08:43:09 -05:00 committed by GitHub
parent c574eb4984
commit f15dd56d0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 2 additions and 2 deletions

View File

@ -330,8 +330,8 @@ class PI0Policy(PreTrainedPolicy):
# For backward pass
loss = losses.mean()
# For logging
loss_dict["l2_loss"] = loss.item()
# For logging. Use detach so won't create scalar to break graph when using torch.compile
loss_dict["l2_loss"] = loss.detach()
return loss, loss_dict