diff --git a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
index b6b78925..92928c70 100644
--- a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
+++ b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py
@@ -66,53 +66,33 @@ class DiffusionUnetImagePolicy(nn.Module):
         self.num_inference_steps = num_inference_steps

     # ========= inference ============
-    def conditional_sample(
-        self,
-        condition_data,
-        inpainting_mask,
-        local_cond=None,
-        global_cond=None,
-        generator=None,
-    ):
-        model = self.unet
-        scheduler = self.noise_scheduler
+    def conditional_sample(self, batch_size, global_cond=None, generator=None):
+        device = get_device_from_parameters(self)
+        dtype = get_dtype_from_parameters(self)

-        trajectory = torch.randn(
-            size=condition_data.shape,
-            dtype=condition_data.dtype,
-            device=condition_data.device,
+        # Sample prior.
+        sample = torch.randn(
+            size=(batch_size, self.horizon, self.action_dim),
+            dtype=dtype,
+            device=device,
             generator=generator,
         )

-        # set step values
-        scheduler.set_timesteps(self.num_inference_steps)
+        self.noise_scheduler.set_timesteps(self.num_inference_steps)

-        for t in scheduler.timesteps:
-            # 1. apply conditioning
-            trajectory[inpainting_mask] = condition_data[inpainting_mask]
-
-            # 2. predict model output
-            model_output = model(
-                trajectory,
-                torch.full(trajectory.shape[:1], t, dtype=torch.long, device=trajectory.device),
-                local_cond=local_cond,
+        for t in self.noise_scheduler.timesteps:
+            # Predict model output.
+            model_output = self.unet(
+                sample,
+                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                 global_cond=global_cond,
            )
+            # Compute previous sample: x_t -> x_t-1  # TODO(now): Is this right?
+            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample

-            # 3. compute previous image: x_t -> x_t-1
-            trajectory = scheduler.step(
-                model_output,
-                t,
-                trajectory,
-                generator=generator,
-            ).prev_sample
+        return sample

-        # finally make sure conditioning is enforced
-        trajectory[inpainting_mask] = condition_data[inpainting_mask]
-
-        return trajectory
-
-    def predict_action(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+    def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """
         This function expects `batch` to have (at least):
         {
@@ -125,27 +105,19 @@ class DiffusionUnetImagePolicy(nn.Module):
         assert n_obs_steps == self.n_obs_steps
         assert self.n_obs_steps == n_obs_steps

-        # build input
-        device = get_device_from_parameters(self)
-        dtype = get_dtype_from_parameters(self)
-
         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
         # Concatenate state and image features then flatten to (B, global_cond_dim).
         global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-        # reshape back to B, Do
-        # empty data for action
-        cond_data = torch.zeros(size=(batch_size, self.horizon, self.action_dim), device=device, dtype=dtype)
-        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)

         # run sampling
-        nsample = self.conditional_sample(cond_data, cond_mask, global_cond=global_cond)
+        sample = self.conditional_sample(batch_size, global_cond=global_cond)

         # `horizon` steps worth of actions (from the first observation).
-        action = nsample[..., : self.action_dim]
-        # Extract `n_action_steps` steps worth of action (from the current observation).
+        action = sample[..., : self.action_dim]
+        # Extract `n_action_steps` steps worth of actions (from the current observation).
         start = n_obs_steps - 1
         end = start + self.n_action_steps
         action = action[:, start:end]
@@ -159,9 +131,10 @@ class DiffusionUnetImagePolicy(nn.Module):
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
             "action": (B, horizon, action_dim)
-            "action_is_pad": (B, horizon)  # TODO(now) maybe this is (B, horizon, 1)
+            "action_is_pad": (B, horizon)
         }
         """
+        # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         horizon = batch["action"].shape[1]
@@ -169,12 +142,6 @@ class DiffusionUnetImagePolicy(nn.Module):
         assert n_obs_steps == self.n_obs_steps
         assert self.n_obs_steps == n_obs_steps

-        # handle different ways of passing observation
-        local_cond = None
-        global_cond = None
-        trajectory = batch["action"]
-        cond_data = trajectory
-
         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
         # Separate batch and sequence dims.
@@ -182,39 +149,39 @@ class DiffusionUnetImagePolicy(nn.Module):
         # Concatenate state and image features then flatten to (B, global_cond_dim).
         global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

-        # Sample noise that we'll add to the images
-        noise = torch.randn(trajectory.shape, device=trajectory.device)
-        # Sample a random timestep for each image
+        trajectory = batch["action"]
+
+        # Forward diffusion.
+        # Sample noise to add to the trajectory.
+        eps = torch.randn(trajectory.shape, device=trajectory.device)
+        # Sample a random noising timestep for each item in the batch.
         timesteps = torch.randint(
-            0,
-            self.noise_scheduler.config.num_train_timesteps,
-            (trajectory.shape[0],),
+            low=0,
+            high=self.noise_scheduler.config.num_train_timesteps,
+            size=(trajectory.shape[0],),
             device=trajectory.device,
         ).long()
-        # Add noise to the clean images according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
+        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
+        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)

-        # Apply inpainting. TODO(now): implement?
-        inpainting_mask = torch.zeros_like(trajectory, dtype=bool)
-        noisy_trajectory[inpainting_mask] = cond_data[inpainting_mask]
-
-        # Predict the noise residual
-        pred = self.unet(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
+        # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
+        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)

+        # Compute the loss.
+        # The target is either the original trajectory, or the noise.
         pred_type = self.noise_scheduler.config.prediction_type
         if pred_type == "epsilon":
-            target = noise
+            target = eps
         elif pred_type == "sample":
-            target = trajectory
+            target = batch["action"]
         else:
             raise ValueError(f"Unsupported prediction type {pred_type}")

         loss = F.mse_loss(pred, target, reduction="none")
-        loss = loss * (~inpainting_mask)

+        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
         if "action_is_pad" in batch:
             in_episode_bound = ~batch["action_is_pad"]
-            loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
+            loss = loss * in_episode_bound.unsqueeze(-1)

         return loss.mean()
diff --git a/lerobot/common/policies/diffusion/model/ema_model.py b/lerobot/common/policies/diffusion/model/ema_model.py
index 3cb1dfbd..1e3447f3 100644
--- a/lerobot/common/policies/diffusion/model/ema_model.py
+++ b/lerobot/common/policies/diffusion/model/ema_model.py
@@ -32,7 +32,7 @@ class EMAModel:
         self.min_value = min_value
         self.max_value = max_value

-        self.decay = 0.0
+        self.alpha = 0.0
         self.optimization_step = 0

     def get_decay(self, optimization_step):
@@ -49,23 +49,20 @@ class EMAModel:

     @torch.no_grad()
     def step(self, new_model):
-        self.decay = self.get_decay(self.optimization_step)
+        self.alpha = self.get_decay(self.optimization_step)

-        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
+        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
+            # Iterate over immediate parameters only.
             for param, ema_param in zip(
-                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
+                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
             ):
-                # iterative over immediate parameters only.
                 if isinstance(param, dict):
                     raise RuntimeError("Dict parameter not supported")
-
-                if isinstance(module, _BatchNorm):
-                    # skip batchnorms
-                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
-                elif not param.requires_grad:
+                if isinstance(module, _BatchNorm) or not param.requires_grad:
+                    # Copy BatchNorm parameters and non-trainable parameters directly.
                     ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                 else:
-                    ema_param.mul_(self.decay)
-                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+                    ema_param.mul_(self.alpha)
+                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)

         self.optimization_step += 1
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index fca89d46..b1713869 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -130,9 +130,9 @@ class DiffusionPolicy(nn.Module):

     def _generate_actions(self, batch):
         if not self.training and self.ema_diffusion is not None:
-            return self.ema_diffusion.predict_action(batch)
+            return self.ema_diffusion.generate_actions(batch)
         else:
-            return self.diffusion.predict_action(batch)
+            return self.diffusion.generate_actions(batch)

     def update(self, batch, **_):
         """Run the model in train mode, compute the loss, and do an optimization step."""
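
For context, the refactored `conditional_sample` is the standard DDPM reverse process from `diffusers`: sample a Gaussian prior, then repeatedly take `scheduler.step(...).prev_sample` to go from x_t to x_{t-1}. A minimal self-contained sketch of that loop, with toy shapes and a stand-in denoiser in place of the conditional U-Net (both are assumptions, not part of this patch):

```python
import torch
from diffusers import DDPMScheduler


def toy_denoiser(sample, timesteps, global_cond=None):
    # Stand-in for the conditional U-Net: a real model predicts the noise
    # (or the clean sample) from `sample`, `timesteps` and `global_cond`.
    return torch.zeros_like(sample)


batch_size, horizon, action_dim = 2, 16, 7
scheduler = DDPMScheduler(num_train_timesteps=100)

# Sample the prior x_T.
sample = torch.randn(batch_size, horizon, action_dim)

scheduler.set_timesteps(10)  # plays the role of num_inference_steps
for t in scheduler.timesteps:
    model_output = toy_denoiser(sample, torch.full((batch_size,), int(t), dtype=torch.long))
    # Compute the previous sample: x_t -> x_{t-1}.
    sample = scheduler.step(model_output, t, sample).prev_sample
```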
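
The slicing at the end of `generate_actions` is worth a worked example: the sampled trajectory spans `horizon` steps starting from the first observation, and the policy executes `n_action_steps` of them starting at the current (latest) observation. With assumed config values (`horizon=16`, `n_obs_steps=2`, `n_action_steps=8`):

```python
import torch

n_obs_steps, n_action_steps, horizon, action_dim = 2, 8, 16, 7  # assumed config
action = torch.randn(1, horizon, action_dim)  # full sampled trajectory

start = n_obs_steps - 1       # 1: index of the current (latest) observation
end = start + n_action_steps  # 9
action = action[:, start:end]  # keep trajectory steps 1..8
assert action.shape == (1, n_action_steps, action_dim)
```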
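
The training side (`compute_loss`) is the usual denoising-diffusion objective: corrupt the clean action trajectory with scheduler-weighted noise, then regress the network output against a target selected by the scheduler's `prediction_type` ("epsilon" targets the noise, "sample" targets the clean trajectory). A minimal sketch, with random tensors standing in for the batch and for the U-Net output (assumptions, not the real model):

```python
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=100, prediction_type="epsilon")

trajectory = torch.randn(2, 16, 7)  # stands in for batch["action"]
eps = torch.randn_like(trajectory)
timesteps = torch.randint(low=0, high=100, size=(trajectory.shape[0],))

# Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
noisy_trajectory = scheduler.add_noise(trajectory, eps, timesteps)

# With prediction_type="epsilon" the regression target is the noise itself.
pred = torch.zeros_like(noisy_trajectory)  # stands in for the U-Net output
loss = F.mse_loss(pred, eps, reduction="none")
```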
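
The padding mask relies on broadcasting: `in_episode_bound.unsqueeze(-1)` turns the `(B, horizon)` bool mask into `(B, horizon, 1)`, which broadcasts across `action_dim` and is promoted to the loss dtype by the multiplication, zeroing out padded steps. A small check with made-up values:

```python
import torch

loss = torch.ones(2, 4, 3)  # (B, horizon, action_dim)
action_is_pad = torch.tensor(
    [[False, False, True, True], [False, False, False, True]]
)  # (B, horizon)

masked_loss = loss * (~action_is_pad).unsqueeze(-1)
print(masked_loss.mean())  # tensor(0.6250): 15 of 24 elements survive the mask
```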
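
In `ema_model.py`, the rename from `decay` to `alpha` does not change the update rule, which remains an in-place exponential moving average; the switch to `zip(..., strict=True)` (Python 3.10+) makes the loops raise `ValueError` if the live and averaged module trees ever diverge, instead of silently truncating. The core update, isolated as a sketch:

```python
import torch


@torch.no_grad()
def ema_update(ema_param: torch.Tensor, param: torch.Tensor, alpha: float) -> None:
    # ema <- alpha * ema + (1 - alpha) * param, in place.
    ema_param.mul_(alpha)
    ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - alpha)


ema_w, w = torch.zeros(3), torch.ones(3)
ema_update(ema_w, w, alpha=0.9)
print(ema_w)  # tensor([0.1000, 0.1000, 0.1000])
```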