diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index 7599fa63..43b11456 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -325,7 +325,8 @@ class PI0Policy(PreTrainedPolicy): loss_dict["losses_after_in_ep_bound"] = losses.clone() # Remove padding - losses = losses[:, :, : self.config.max_action_dim] + original_action_dim = self.config.action_feature.shape[0] + losses = losses[:, :, :original_action_dim] loss_dict["losses_after_rm_padding"] = losses.clone() # For backward pass