fix: action_is_pad was missing in compute_loss
This commit is contained in:
parent
ad3379a73a
commit
a420714ee4
|
@ -243,9 +243,12 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||||
result = {"action": action, "action_pred": action_pred}
|
result = {"action": action, "action_pred": action_pred}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def compute_loss(self, obs_dict, action):
|
def compute_loss(self, batch):
|
||||||
nobs = obs_dict
|
nobs = {
|
||||||
nactions = action
|
"image": batch["observation.image"],
|
||||||
|
"agent_pos": batch["observation.state"],
|
||||||
|
}
|
||||||
|
nactions = batch["action"]
|
||||||
batch_size = nactions.shape[0]
|
batch_size = nactions.shape[0]
|
||||||
horizon = nactions.shape[1]
|
horizon = nactions.shape[1]
|
||||||
|
|
||||||
|
@ -302,6 +305,11 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||||
|
|
||||||
loss = F.mse_loss(pred, target, reduction="none")
|
loss = F.mse_loss(pred, target, reduction="none")
|
||||||
loss = loss * loss_mask.type(loss.dtype)
|
loss = loss * loss_mask.type(loss.dtype)
|
||||||
loss = reduce(loss, "b ... -> b (...)", "mean")
|
|
||||||
|
if "action_is_pad" in batch:
|
||||||
|
in_episode_bound = ~batch["action_is_pad"]
|
||||||
|
loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
|
||||||
|
|
||||||
|
loss = reduce(loss, "b t c -> b", "mean", b=batch_size)
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
return loss
|
return loss
|
||||||
|
|
|
@ -153,12 +153,7 @@ class DiffusionPolicy(nn.Module):
|
||||||
|
|
||||||
data_s = time.time() - start_time
|
data_s = time.time() - start_time
|
||||||
|
|
||||||
obs_dict = {
|
loss = self.diffusion.compute_loss(batch)
|
||||||
"image": batch["observation.image"],
|
|
||||||
"agent_pos": batch["observation.state"],
|
|
||||||
}
|
|
||||||
action = batch["action"]
|
|
||||||
loss = self.diffusion.compute_loss(obs_dict, action)
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
|
Loading…
Reference in New Issue