Changed loss.item() to loss.detach() so that no Python scalar is created; calling .item() forces a graph break when using torch.compile, while .detach() keeps the value as a tensor and avoids it.
This commit is contained in:
parent
c574eb4984
commit
f15dd56d0e
|
@ -330,8 +330,8 @@ class PI0Policy(PreTrainedPolicy):
|
|||
|
||||
# For backward pass
|
||||
loss = losses.mean()
|
||||
# For logging
|
||||
loss_dict["l2_loss"] = loss.item()
|
||||
# For logging. Use detach so won't create scalar to break graph when using torch.compile
|
||||
loss_dict["l2_loss"] = loss.detach()
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
|
|
Loading…
Reference in New Issue