throw an error if config.do_maks_loss and action_is_pad not provided in batch (#213)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
parent
6d39b73399
commit
3b86050ab0
|
@ -304,7 +304,11 @@ class DiffusionModel(nn.Module):
|
|||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
|
||||
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
|
||||
if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||
if self.config.do_mask_loss_for_padding:
|
||||
if "action_is_pad" not in batch:
|
||||
raise ValueError(
|
||||
f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
|
||||
)
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue