Compare commits

...

4 Commits

Author SHA1 Message Date
IrvingF7 2611e3bf8f
Merge 2fe78bbfac into 768e36660d 2025-04-15 21:19:27 +08:00
Steven Palma 2fe78bbfac
Merge branch 'main' into main 2025-03-05 01:37:18 +01:00
IrvingF7 735258d121
Merge branch 'main' into main 2025-02-15 12:21:53 -05:00
IrvingF7 f15dd56d0e
modified loss.item to loss.detach, so that no scalar will be created and when using torch.compile, it will not break graph 2025-02-14 08:43:09 -05:00
1 changed files with 2 additions and 2 deletions

View File

@ -330,8 +330,8 @@ class PI0Policy(PreTrainedPolicy):
# For backward pass
loss = losses.mean()
# For logging
loss_dict["l2_loss"] = loss.item()
# For logging. Use detach so won't create scalar to break graph when using torch.compile
loss_dict["l2_loss"] = loss.detach()
return loss, loss_dict