Merge 2fe78bbfac
into 145fe4cd17
This commit is contained in:
commit
713eab7e3b
lerobot/common/policies/pi0
|
@ -330,8 +330,8 @@ class PI0Policy(PreTrainedPolicy):
|
|||
|
||||
# For backward pass
|
||||
loss = losses.mean()
|
||||
# For logging
|
||||
loss_dict["l2_loss"] = loss.item()
|
||||
# For logging. Use detach so won't create scalar to break graph when using torch.compile
|
||||
loss_dict["l2_loss"] = loss.detach()
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
|
|
Loading…
Reference in New Issue