From 334d9e92bd3c84eb25de01bebef4a468f0ebb089 Mon Sep 17 00:00:00 2001 From: Mahi Shafiullah <3000253+notmahi@users.noreply.github.com> Date: Mon, 3 Feb 2025 23:31:39 -0500 Subject: [PATCH] Add different sampling algorithm. --- .../dit_flow/configuration_dit_flow.py | 8 +++++ .../policies/dit_flow/modeling_dit_flow.py | 33 ++++++++++++++----- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/lerobot/common/policies/dit_flow/configuration_dit_flow.py b/lerobot/common/policies/dit_flow/configuration_dit_flow.py index 60093024..571764c2 100644 --- a/lerobot/common/policies/dit_flow/configuration_dit_flow.py +++ b/lerobot/common/policies/dit_flow/configuration_dit_flow.py @@ -130,6 +130,9 @@ class DiTFlowConfig(PreTrainedConfig): activation: str = "gelu" # Noise scheduler. + training_noise_sampling: str = ( + "beta" # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf + ) clip_sample: bool = True clip_sample_range: float = 1.0 @@ -156,6 +159,11 @@ class DiTFlowConfig(PreTrainedConfig): f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) + if self.training_noise_sampling not in ("uniform", "beta"): + raise ValueError( + f"`training_noise_sampling` must be either 'uniform' or 'beta'. Got {self.training_noise_sampling}." + ) + def get_optimizer_preset(self) -> AdamConfig: return AdamConfig( lr=self.optimizer_lr, diff --git a/lerobot/common/policies/dit_flow/modeling_dit_flow.py b/lerobot/common/policies/dit_flow/modeling_dit_flow.py index 4d82adbb..547ad0ab 100644 --- a/lerobot/common/policies/dit_flow/modeling_dit_flow.py +++ b/lerobot/common/policies/dit_flow/modeling_dit_flow.py @@ -3,7 +3,7 @@ # Heavy inspiration taken from # * DETR by Meta AI (Carion et. al.): https://github.com/facebookresearch/detr # * DiT by Meta AI (Peebles and Xie): https://github.com/facebookresearch/DiT -# * DiT Policy by Dasari et. al. :https://dit-policy.github.io/ +# * DiT Policy by Dasari et. al. 
: https://dit-policy.github.io/ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. @@ -243,9 +243,7 @@ class _DiTNoiseNet(nn.Module): self.clip_sample = clip_sample self.clip_sample_range = clip_sample_range - print( - "number of diffusion parameters: {:2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6) - ) + print("Number of flow params: {:.2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6)) def forward(self, noisy_actions, time, global_cond): c = self.cond_proj(global_cond) @@ -434,6 +432,26 @@ class DiTFlowModel(nn.Module): ) self.num_inference_steps = config.num_inference_steps or 100 + self.training_noise_sampling = config.training_noise_sampling + if config.training_noise_sampling == "uniform": + self.noise_distribution = torch.distributions.Uniform( + low=0, + high=1, + ) + elif config.training_noise_sampling == "beta": + # From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B. + # There, they say the PDF for the distribution they use is the following: + # $p(t) = \mathrm{Beta}\left(\frac{s - t}{s};\ 1.5,\ 1\right)$ + # So, we first sample $t' = (s - t) / s \sim \mathrm{Beta}(1.5, 1)$ and then map it back via $t = s - s \cdot t'$. + s = 0.999 # constant from the paper + beta_dist = torch.distributions.Beta( + concentration1=1.5, # alpha + concentration0=1.0, # beta + ) + affine_transform = torch.distributions.transforms.AffineTransform(loc=s, scale=-s) + self.noise_distribution = torch.distributions.TransformedDistribution( + beta_dist, [affine_transform] + ) # ========= inference ============ def conditional_sample( @@ -447,7 +465,7 @@ class DiTFlowModel(nn.Module): # Expand global conditioning to the batch size. if global_cond is not None: - global_cond = global_cond.unsqueeze(0).expand(batch_size, -1).to(device=device, dtype=dtype) + global_cond = global_cond.expand(batch_size, -1).to(device=device, dtype=dtype) # Sample prior. 
sample = self.velocity_net.sample( @@ -550,10 +568,7 @@ class DiTFlowModel(nn.Module): # Sample noise to add to the trajectory. noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device) # Sample a random noising timestep for each item in the batch. - timesteps = torch.rand( - size=(trajectory.shape[0],), - device=trajectory.device, - ) + timesteps = self.noise_distribution.sample((trajectory.shape[0],)).to(trajectory.device) # Add noise to the clean trajectories according to the noise magnitude at each timestep. noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory