diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 73fabefa..d7341c33 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -64,6 +64,9 @@ class DiffusionConfig: clip_sample_range: The magnitude of the clipping range as described above. num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly spaced). If not provided, this defaults to be the same as `num_train_timesteps`. + do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See + `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults + to False as the original Diffusion Policy implementation does the same. """ # Inputs / output structure. @@ -118,6 +121,9 @@ class DiffusionConfig: # Inference num_inference_steps: int | None = None + # Loss computation + do_mask_loss_for_padding: bool = False + def __post_init__(self): """Input validation (not exhaustive).""" if not self.vision_backbone.startswith("resnet"): diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index f5f64d80..91cf6dd0 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -268,7 +268,7 @@ class DiffusionModel(nn.Module): loss = F.mse_loss(pred, target, reduction="none") # Mask loss wherever the action is padded with copies (edges of the dataset trajectory). 
- if "action_is_pad" in batch: + if self.config.do_mask_loss_for_padding and "action_is_pad" in batch: in_episode_bound = ~batch["action_is_pad"] loss = loss * in_episode_bound.unsqueeze(-1) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 60061c38..2d611c88 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -95,3 +95,6 @@ policy: # Inference num_inference_steps: 100 + + # Loss computation + do_mask_loss_for_padding: false diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors index 1b1142b2..f27cd678 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors index 77472bb5..5f33535d 100644 Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors differ