Add different sampling algorithm.

This commit is contained in:
Mahi Shafiullah 2025-02-03 23:31:39 -05:00
parent e82b4c9460
commit 334d9e92bd
2 changed files with 32 additions and 9 deletions

View File

@ -130,6 +130,9 @@ class DiTFlowConfig(PreTrainedConfig):
activation: str = "gelu" activation: str = "gelu"
# Noise scheduler. # Noise scheduler.
training_noise_sampling: str = (
"beta" # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
)
clip_sample: bool = True clip_sample: bool = True
clip_sample_range: float = 1.0 clip_sample_range: float = 1.0
@ -156,6 +159,11 @@ class DiTFlowConfig(PreTrainedConfig):
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
) )
if self.training_noise_sampling not in ("uniform", "beta"):
raise ValueError(
f"`training_noise_sampling` must be either 'uniform' or 'beta'. Got {self.training_noise_sampling}."
)
def get_optimizer_preset(self) -> AdamConfig: def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig( return AdamConfig(
lr=self.optimizer_lr, lr=self.optimizer_lr,

View File

@ -3,7 +3,7 @@
# Heavy inspiration taken from # Heavy inspiration taken from
# * DETR by Meta AI (Carion et. al.): https://github.com/facebookresearch/detr # * DETR by Meta AI (Carion et. al.): https://github.com/facebookresearch/detr
# * DiT by Meta AI (Peebles and Xie): https://github.com/facebookresearch/DiT # * DiT by Meta AI (Peebles and Xie): https://github.com/facebookresearch/DiT
# * DiT Policy by Dasari et. al. :https://dit-policy.github.io/ # * DiT Policy by Dasari et. al. : https://dit-policy.github.io/
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
@ -243,9 +243,7 @@ class _DiTNoiseNet(nn.Module):
self.clip_sample = clip_sample self.clip_sample = clip_sample
self.clip_sample_range = clip_sample_range self.clip_sample_range = clip_sample_range
print( print("Number of flow params: {:.2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6))
"number of diffusion parameters: {:2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6)
)
def forward(self, noisy_actions, time, global_cond): def forward(self, noisy_actions, time, global_cond):
c = self.cond_proj(global_cond) c = self.cond_proj(global_cond)
@ -434,6 +432,26 @@ class DiTFlowModel(nn.Module):
) )
self.num_inference_steps = config.num_inference_steps or 100 self.num_inference_steps = config.num_inference_steps or 100
self.training_noise_sampling = config.training_noise_sampling
if config.training_noise_sampling == "uniform":
self.noise_distribution = torch.distributions.Uniform(
low=0,
high=1,
)
elif config.training_noise_sampling == "beta":
# From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B.
# There, they say the PDF for the distribution they use is the following:
# $p(t) = Beta((s - t) / s; 1.5, 1)$
# So, we first sample $t' = (s - t) / s$ from Beta(1.5, 1) and then map it back to $t = s - s * t'$.
s = 0.999 # constant from the paper
beta_dist = torch.distributions.Beta(
concentration1=1.5, # alpha
concentration0=1.0, # beta
)
affine_transform = torch.distributions.transforms.AffineTransform(loc=s, scale=-s)
self.noise_distribution = torch.distributions.TransformedDistribution(
beta_dist, [affine_transform]
)
# ========= inference ============ # ========= inference ============
def conditional_sample( def conditional_sample(
@ -447,7 +465,7 @@ class DiTFlowModel(nn.Module):
# Expand global conditioning to the batch size. # Expand global conditioning to the batch size.
if global_cond is not None: if global_cond is not None:
global_cond = global_cond.unsqueeze(0).expand(batch_size, -1).to(device=device, dtype=dtype) global_cond = global_cond.expand(batch_size, -1).to(device=device, dtype=dtype)
# Sample prior. # Sample prior.
sample = self.velocity_net.sample( sample = self.velocity_net.sample(
@ -550,10 +568,7 @@ class DiTFlowModel(nn.Module):
# Sample noise to add to the trajectory. # Sample noise to add to the trajectory.
noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device) noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device)
# Sample a random noising timestep for each item in the batch. # Sample a random noising timestep for each item in the batch.
timesteps = torch.rand( timesteps = self.noise_distribution.sample((trajectory.shape[0],)).to(trajectory.device)
size=(trajectory.shape[0],),
device=trajectory.device,
)
# Add noise to the clean trajectories according to the noise magnitude at each timestep. # Add noise to the clean trajectories according to the noise magnitude at each timestep.
noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory