From a420714ee4e8c4d5afbc8cc65fe98ef78965a8cf Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 5 Apr 2024 11:33:39 +0000 Subject: [PATCH] fix: action_is_pad was missing in compute_loss --- .../diffusion/diffusion_unet_image_policy.py | 16 ++++++++++++---- lerobot/common/policies/diffusion/policy.py | 7 +------ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py index 373e4b6c..f7432db3 100644 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py @@ -243,9 +243,12 @@ class DiffusionUnetImagePolicy(BaseImagePolicy): result = {"action": action, "action_pred": action_pred} return result - def compute_loss(self, obs_dict, action): - nobs = obs_dict - nactions = action + def compute_loss(self, batch): + nobs = { + "image": batch["observation.image"], + "agent_pos": batch["observation.state"], + } + nactions = batch["action"] batch_size = nactions.shape[0] horizon = nactions.shape[1] @@ -302,6 +305,11 @@ class DiffusionUnetImagePolicy(BaseImagePolicy): loss = F.mse_loss(pred, target, reduction="none") loss = loss * loss_mask.type(loss.dtype) - loss = reduce(loss, "b ... -> b (...)", "mean") + + if "action_is_pad" in batch: + in_episode_bound = ~batch["action_is_pad"] + loss = loss * in_episode_bound[:, :, None].type(loss.dtype) + + loss = reduce(loss, "b t c -> b", "mean", b=batch_size) loss = loss.mean() return loss diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index de8796ab..a4f4a450 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -153,12 +153,7 @@ class DiffusionPolicy(nn.Module): data_s = time.time() - start_time - obs_dict = { - "image": batch["observation.image"], - "agent_pos": batch["observation.state"], - } - action = batch["action"] - loss = self.diffusion.compute_loss(obs_dict, action) + loss = self.diffusion.compute_loss(batch) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(