This commit is contained in:
Yachen Kang 2025-04-15 10:18:22 +08:00 committed by GitHub
commit 959e8b0d7a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -325,7 +325,8 @@ class PI0Policy(PreTrainedPolicy):
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
original_action_dim = self.config.action_feature.shape[0]
losses = losses[:, :, :original_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
# For backward pass