fix: action_is_pad was missing in compute_loss

This commit is contained in:
Cadene 2024-04-05 11:33:39 +00:00
parent ad3379a73a
commit a420714ee4
2 changed files with 13 additions and 10 deletions

View File

@@ -243,9 +243,12 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
         result = {"action": action, "action_pred": action_pred}
         return result
-    def compute_loss(self, obs_dict, action):
-        nobs = obs_dict
-        nactions = action
+    def compute_loss(self, batch):
+        nobs = {
+            "image": batch["observation.image"],
+            "agent_pos": batch["observation.state"],
+        }
+        nactions = batch["action"]
         batch_size = nactions.shape[0]
         horizon = nactions.shape[1]
@@ -302,6 +305,11 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
         loss = F.mse_loss(pred, target, reduction="none")
         loss = loss * loss_mask.type(loss.dtype)
-        loss = reduce(loss, "b ... -> b (...)", "mean")
+        if "action_is_pad" in batch:
+            in_episode_bound = ~batch["action_is_pad"]
+            loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
+        loss = reduce(loss, "b t c -> b", "mean", b=batch_size)
         loss = loss.mean()
         return loss

View File

@@ -153,12 +153,7 @@ class DiffusionPolicy(nn.Module):
         data_s = time.time() - start_time
-        obs_dict = {
-            "image": batch["observation.image"],
-            "agent_pos": batch["observation.state"],
-        }
-        action = batch["action"]
-        loss = self.diffusion.compute_loss(obs_dict, action)
+        loss = self.diffusion.compute_loss(batch)
         loss.backward()
         grad_norm = torch.nn.utils.clip_grad_norm_(