Updated standard scheduling
This commit is contained in:
parent
5d7a0ce32e
commit
86e75ab7f8
|
@ -131,7 +131,7 @@ class DiTFlowConfig(PreTrainedConfig):
|
|||
|
||||
# Noise scheduler.
|
||||
training_noise_sampling: str = (
|
||||
"beta" # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
|
||||
"uniform" # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
|
||||
)
|
||||
clip_sample: bool = True
|
||||
clip_sample_range: float = 1.0
|
||||
|
|
|
@ -139,9 +139,6 @@ class _DiTDecoder(nn.Module):
|
|||
x = x + self.attn_gate(self.dropout1(x2), cond)
|
||||
|
||||
x3 = self.mlp_modulate(self.norm2(x), cond)
|
||||
# TODO: verify and then remove
|
||||
# x3 = self.linear2(self.dropout2(self.activation(self.linear1(x3))))
|
||||
# x3 = self.mlp_gate(self.dropout3(x3), cond)
|
||||
x3 = self.mlp(x3)
|
||||
x3 = self.mlp_gate(x3, cond)
|
||||
return x + x3
|
||||
|
@ -202,7 +199,7 @@ class _DiTNoiseNet(nn.Module):
|
|||
num_blocks=6,
|
||||
dropout=0.1,
|
||||
dim_feedforward=2048,
|
||||
nhead=6,
|
||||
nhead=8,
|
||||
activation="gelu",
|
||||
clip_sample=False,
|
||||
clip_sample_range=1.0,
|
||||
|
@ -441,8 +438,8 @@ class DiTFlowModel(nn.Module):
|
|||
elif config.training_noise_sampling == "beta":
|
||||
# From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B.
|
||||
# There, they say the PDF for the distribution they use is the following:
|
||||
# $p(t) = Beta(s-t / s; 1.5, 1)$
|
||||
# So, we first figure out the distribution over $t' = s - s * t$ and then transform it to $t$.
|
||||
# $p(t) = Beta((s-t) / s; 1.5, 1)$
|
||||
# So, we first figure out the distribution over $t'$ and then transform it to $t = s - s * t'$.
|
||||
s = 0.999 # constant from the paper
|
||||
beta_dist = torch.distributions.Beta(
|
||||
concentration1=1.5, # alpha
|
||||
|
|
Loading…
Reference in New Issue