Updated standard scheduling

This commit is contained in:
Mahi Shafiullah 2025-02-04 10:58:43 -05:00
parent 5d7a0ce32e
commit 86e75ab7f8
2 changed files with 4 additions and 7 deletions

View File

@ -131,7 +131,7 @@ class DiTFlowConfig(PreTrainedConfig):
# Noise scheduler. # Noise scheduler.
training_noise_sampling: str = ( training_noise_sampling: str = (
"beta" # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf "uniform" # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
) )
clip_sample: bool = True clip_sample: bool = True
clip_sample_range: float = 1.0 clip_sample_range: float = 1.0

View File

@ -139,9 +139,6 @@ class _DiTDecoder(nn.Module):
x = x + self.attn_gate(self.dropout1(x2), cond) x = x + self.attn_gate(self.dropout1(x2), cond)
x3 = self.mlp_modulate(self.norm2(x), cond) x3 = self.mlp_modulate(self.norm2(x), cond)
# TODO: verify and then remove
# x3 = self.linear2(self.dropout2(self.activation(self.linear1(x3))))
# x3 = self.mlp_gate(self.dropout3(x3), cond)
x3 = self.mlp(x3) x3 = self.mlp(x3)
x3 = self.mlp_gate(x3, cond) x3 = self.mlp_gate(x3, cond)
return x + x3 return x + x3
@ -202,7 +199,7 @@ class _DiTNoiseNet(nn.Module):
num_blocks=6, num_blocks=6,
dropout=0.1, dropout=0.1,
dim_feedforward=2048, dim_feedforward=2048,
nhead=6, nhead=8,
activation="gelu", activation="gelu",
clip_sample=False, clip_sample=False,
clip_sample_range=1.0, clip_sample_range=1.0,
@ -441,8 +438,8 @@ class DiTFlowModel(nn.Module):
elif config.training_noise_sampling == "beta": elif config.training_noise_sampling == "beta":
# From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B. # From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B.
# There, they say the PDF for the distribution they use is the following: # There, they say the PDF for the distribution they use is the following:
# $p(t) = Beta(s-t / s; 1.5, 1)$ # $p(t) = Beta((s-t) / s; 1.5, 1)$
# So, we first figure out the distribution over $t' = s - s * t$ and then transform it to $t$. # So, we first figure out the distribution over $t'$ and then transform it to $t = s - s * t'$.
s = 0.999 # constant from the paper s = 0.999 # constant from the paper
beta_dist = torch.distributions.Beta( beta_dist = torch.distributions.Beta(
concentration1=1.5, # alpha concentration1=1.5, # alpha