fix: action_is_pad was missing in compute_loss

Cadene 2024-04-05 11:33:39 +00:00
parent ad3379a73a
commit a420714ee4
2 changed files with 13 additions and 10 deletions

@@ -243,9 +243,12 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
         result = {"action": action, "action_pred": action_pred}
         return result
 
-    def compute_loss(self, obs_dict, action):
-        nobs = obs_dict
-        nactions = action
+    def compute_loss(self, batch):
+        nobs = {
+            "image": batch["observation.image"],
+            "agent_pos": batch["observation.state"],
+        }
+        nactions = batch["action"]
         batch_size = nactions.shape[0]
         horizon = nactions.shape[1]
 
@@ -302,6 +305,11 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
         loss = F.mse_loss(pred, target, reduction="none")
         loss = loss * loss_mask.type(loss.dtype)
-        loss = reduce(loss, "b ... -> b (...)", "mean")
+
+        if "action_is_pad" in batch:
+            in_episode_bound = ~batch["action_is_pad"]
+            loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
+
+        loss = reduce(loss, "b t c -> b", "mean", b=batch_size)
         loss = loss.mean()
         return loss
 
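For context, here is a minimal, self-contained sketch of what the new masking does, with dummy shapes and an illustrative mask (only `action_is_pad` and the reduce pattern come from the diff; every other name and value is made up):

    import torch
    from einops import reduce

    b, t, c = 2, 4, 3  # batch size, action horizon, action dim (dummy sizes)
    loss = torch.ones(b, t, c)  # stand-in for the per-element MSE loss
    # Hypothetical mask: sample 0 has its last two actions padded past the episode end.
    action_is_pad = torch.tensor([[False, False, True, True],
                                  [False, False, False, False]])

    in_episode_bound = ~action_is_pad  # True for timesteps inside the episode
    loss = loss * in_episode_bound[:, :, None].type(loss.dtype)  # zero out padded steps
    loss = reduce(loss, "b t c -> b", "mean", b=b)
    print(loss)  # tensor([0.5000, 1.0000])

Note that padded timesteps are averaged in as zeros rather than excluded, so heavily padded samples pull the per-sample loss down; the diff keeps this simple zeroing rather than renormalizing by the number of valid steps.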

@@ -153,12 +153,7 @@ class DiffusionPolicy(nn.Module):
         data_s = time.time() - start_time
 
-        obs_dict = {
-            "image": batch["observation.image"],
-            "agent_pos": batch["observation.state"],
-        }
-        action = batch["action"]
-        loss = self.diffusion.compute_loss(obs_dict, action)
+        loss = self.diffusion.compute_loss(batch)
 
         loss.backward()
         grad_norm = torch.nn.utils.clip_grad_norm_(
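For intuition on where `action_is_pad` comes from: it is produced on the dataset side when a sampled action horizon overruns the episode end. A hypothetical sketch, assuming last-action padding (the function name and padding strategy are illustrative, not the repository's actual dataset code):

    import torch

    def pad_action_horizon(actions: torch.Tensor, start: int, horizon: int):
        # Slice `horizon` steps starting at `start`; if the slice runs past
        # the episode end, repeat the last action and flag the padded steps.
        episode_len = actions.shape[0]
        chunk = actions[start : start + horizon]
        n_pad = horizon - chunk.shape[0]
        if n_pad > 0:
            chunk = torch.cat([chunk, chunk[-1:].expand(n_pad, -1)])
        action_is_pad = torch.arange(horizon) >= (episode_len - start)
        return chunk, action_is_pad

    # An episode of 5 actions, with a 4-step horizon sampled at t=3:
    chunk, is_pad = pad_action_horizon(torch.randn(5, 3), start=3, horizon=4)
    print(is_pad)  # tensor([False, False,  True,  True])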