throw an error if config.do_maks_loss and action_is_pad not provided in batch (#213)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Radek Osmulski 2024-05-27 18:06:26 +10:00 committed by GitHub
parent 6d39b73399
commit 3b86050ab0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 1 deletions

View File

@ -304,7 +304,11 @@ class DiffusionModel(nn.Module):
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
if self.config.do_mask_loss_for_padding:
if "action_is_pad" not in batch:
raise ValueError(
f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
)
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)