Add different sampling algorithm.
This commit is contained in:
parent
e82b4c9460
commit
334d9e92bd
|
@ -130,6 +130,9 @@ class DiTFlowConfig(PreTrainedConfig):
|
||||||
activation: str = "gelu"
|
activation: str = "gelu"
|
||||||
|
|
||||||
# Noise scheduler.
|
# Noise scheduler.
|
||||||
|
training_noise_sampling: str = (
|
||||||
|
"beta" # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
|
||||||
|
)
|
||||||
clip_sample: bool = True
|
clip_sample: bool = True
|
||||||
clip_sample_range: float = 1.0
|
clip_sample_range: float = 1.0
|
||||||
|
|
||||||
|
@ -156,6 +159,11 @@ class DiTFlowConfig(PreTrainedConfig):
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.training_noise_sampling not in ("uniform", "beta"):
|
||||||
|
raise ValueError(
|
||||||
|
f"`training_noise_sampling` must be either 'uniform' or 'beta'. Got {self.training_noise_sampling}."
|
||||||
|
)
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> AdamConfig:
|
def get_optimizer_preset(self) -> AdamConfig:
|
||||||
return AdamConfig(
|
return AdamConfig(
|
||||||
lr=self.optimizer_lr,
|
lr=self.optimizer_lr,
|
||||||
|
|
|
@ -243,9 +243,7 @@ class _DiTNoiseNet(nn.Module):
|
||||||
self.clip_sample = clip_sample
|
self.clip_sample = clip_sample
|
||||||
self.clip_sample_range = clip_sample_range
|
self.clip_sample_range = clip_sample_range
|
||||||
|
|
||||||
print(
|
print("Number of flow params: {:.2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6))
|
||||||
"number of diffusion parameters: {:2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, noisy_actions, time, global_cond):
|
def forward(self, noisy_actions, time, global_cond):
|
||||||
c = self.cond_proj(global_cond)
|
c = self.cond_proj(global_cond)
|
||||||
|
@ -434,6 +432,26 @@ class DiTFlowModel(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.num_inference_steps = config.num_inference_steps or 100
|
self.num_inference_steps = config.num_inference_steps or 100
|
||||||
|
self.training_noise_sampling = config.training_noise_sampling
|
||||||
|
if config.training_noise_sampling == "uniform":
|
||||||
|
self.noise_distribution = torch.distributions.Uniform(
|
||||||
|
low=0,
|
||||||
|
high=1,
|
||||||
|
)
|
||||||
|
elif config.training_noise_sampling == "beta":
|
||||||
|
# From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B.
|
||||||
|
# There, they say the PDF for the distribution they use is the following:
|
||||||
|
# $p(t) = Beta((s - t) / s; 1.5, 1)$
|
||||||
|
# So, we first sample $t' = (s - t) / s$ from $Beta(1.5, 1)$ and then map it back to $t = s - s * t'$.
|
||||||
|
s = 0.999 # constant from the paper
|
||||||
|
beta_dist = torch.distributions.Beta(
|
||||||
|
concentration1=1.5, # alpha
|
||||||
|
concentration0=1.0, # beta
|
||||||
|
)
|
||||||
|
affine_transform = torch.distributions.transforms.AffineTransform(loc=s, scale=-s)
|
||||||
|
self.noise_distribution = torch.distributions.TransformedDistribution(
|
||||||
|
beta_dist, [affine_transform]
|
||||||
|
)
|
||||||
|
|
||||||
# ========= inference ============
|
# ========= inference ============
|
||||||
def conditional_sample(
|
def conditional_sample(
|
||||||
|
@ -447,7 +465,7 @@ class DiTFlowModel(nn.Module):
|
||||||
|
|
||||||
# Expand global conditioning to the batch size.
|
# Expand global conditioning to the batch size.
|
||||||
if global_cond is not None:
|
if global_cond is not None:
|
||||||
global_cond = global_cond.unsqueeze(0).expand(batch_size, -1).to(device=device, dtype=dtype)
|
global_cond = global_cond.expand(batch_size, -1).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
# Sample prior.
|
# Sample prior.
|
||||||
sample = self.velocity_net.sample(
|
sample = self.velocity_net.sample(
|
||||||
|
@ -550,10 +568,7 @@ class DiTFlowModel(nn.Module):
|
||||||
# Sample noise to add to the trajectory.
|
# Sample noise to add to the trajectory.
|
||||||
noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device)
|
noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device)
|
||||||
# Sample a random noising timestep for each item in the batch.
|
# Sample a random noising timestep for each item in the batch.
|
||||||
timesteps = torch.rand(
|
timesteps = self.noise_distribution.sample((trajectory.shape[0],)).to(trajectory.device)
|
||||||
size=(trajectory.shape[0],),
|
|
||||||
device=trajectory.device,
|
|
||||||
)
|
|
||||||
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
|
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
|
||||||
noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory
|
noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue