Updated standard scheduling
parent 5d7a0ce32e
commit 86e75ab7f8
@@ -131,7 +131,7 @@ class DiTFlowConfig(PreTrainedConfig):
 
     # Noise scheduler.
     training_noise_sampling: str = (
-        "beta"  # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
+        "uniform"  # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
     )
     clip_sample: bool = True
     clip_sample_range: float = 1.0
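For context, `clip_sample` and `clip_sample_range` in schedulers of this kind are usually used to clamp predicted samples to a symmetric range during denoising. A minimal sketch of that pattern, assuming these fields are consumed the same way here (the helper name below is illustrative, not taken from this repository):

import torch

def maybe_clip_sample(sample: torch.Tensor, clip_sample: bool = True, clip_sample_range: float = 1.0) -> torch.Tensor:
    # Hypothetical helper: clamp to [-clip_sample_range, clip_sample_range] when clipping is enabled.
    if clip_sample:
        return sample.clamp(-clip_sample_range, clip_sample_range)
    return sample
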
@@ -139,9 +139,6 @@ class _DiTDecoder(nn.Module):
         x = x + self.attn_gate(self.dropout1(x2), cond)
 
         x3 = self.mlp_modulate(self.norm2(x), cond)
-        # TODO: verify and then remove
-        # x3 = self.linear2(self.dropout2(self.activation(self.linear1(x3))))
-        # x3 = self.mlp_gate(self.dropout3(x3), cond)
         x3 = self.mlp(x3)
         x3 = self.mlp_gate(x3, cond)
         return x + x3
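The `mlp_modulate` / `mlp_gate` calls above follow the adaLN-style conditioning used in DiT blocks: the conditioning vector produces a scale-and-shift applied before the MLP and a gate applied to the residual branch. A minimal sketch of what such modules commonly look like (this is an assumption about their shape, not the repository's actual implementation; `cond` is assumed to broadcast over the sequence dimension):

import torch
import torch.nn as nn

class Modulate(nn.Module):
    # Scale-and-shift modulation conditioned on `cond`.
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, 2 * dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        shift, scale = self.proj(cond).chunk(2, dim=-1)
        return x * (1 + scale) + shift

class Gate(nn.Module):
    # Per-channel gate applied to a residual branch before it is added back.
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        return x * self.proj(cond)
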
@@ -202,7 +199,7 @@ class _DiTNoiseNet(nn.Module):
         num_blocks=6,
         dropout=0.1,
         dim_feedforward=2048,
-        nhead=6,
+        nhead=8,
         activation="gelu",
         clip_sample=False,
         clip_sample_range=1.0,
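The default head count moves from 6 to 8. If the attention layers underneath are standard PyTorch modules, the embedding width must be divisible by `nhead`; a quick check, with a hypothetical `hidden_dim` standing in for whatever width the network actually uses:

import torch.nn as nn

hidden_dim = 768  # hypothetical value; the real width is defined elsewhere in the model
nhead = 8

# nn.MultiheadAttention rejects configurations where embed_dim is not divisible by num_heads,
# so the new default only works for widths divisible by 8.
assert hidden_dim % nhead == 0
attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=nhead, batch_first=True)
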
@@ -441,8 +438,8 @@ class DiTFlowModel(nn.Module):
         elif config.training_noise_sampling == "beta":
             # From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B.
             # There, they say the PDF for the distribution they use is the following:
-            # $p(t) = Beta(s-t / s; 1.5, 1)$
-            # So, we first figure out the distribution over $t' = s - s * t$ and then transform it to $t$.
+            # $p(t) = Beta((s-t) / s; 1.5, 1)$
+            # So, we first figure out the distribution over $t'$ and then transform it to $t = s - s * t'$.
             s = 0.999  # constant from the paper
             beta_dist = torch.distributions.Beta(
                 concentration1=1.5,  # alpha
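The corrected comments describe drawing an auxiliary variable t' from Beta(1.5, 1) and then mapping it onto the flow-matching time axis via t = s - s * t'. A minimal sketch of that sampling step, following the comments (the batch size and the `concentration0=1.0` keyword, matching the second shape parameter of 1 in the quoted PDF, are assumptions since the hunk is truncated):

import torch

s = 0.999  # constant from the Pi0 paper, Appendix B
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)

batch_size = 8
t_prime = beta_dist.sample((batch_size,))  # t' ~ Beta(1.5, 1), values in [0, 1]
t = s - s * t_prime                        # map back to flow time: t = s * (1 - t'), so t stays at or below s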