Merge 2fe78bbfac
into b43ece8934
This commit is contained in:
commit
d6fe46c5a0
|
@ -330,8 +330,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||||
|
|
||||||
# For backward pass
|
# For backward pass
|
||||||
loss = losses.mean()
|
loss = losses.mean()
|
||||||
# For logging
|
# For logging. Use detach so won't create scalar to break graph when using torch.compile
|
||||||
loss_dict["l2_loss"] = loss.item()
|
loss_dict["l2_loss"] = loss.detach()
|
||||||
|
|
||||||
return loss, loss_dict
|
return loss, loss_dict
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue