add bin pred temperature
This commit is contained in:
parent
8ee1e53fee
commit
b3e0ec1afd
|
@ -108,6 +108,7 @@ class VQBeTConfig:
|
|||
mlp_hidden_dim: int = 1024
|
||||
offset_loss_weight: float = 10000.
|
||||
secondary_code_loss_weight: float = 0.5
|
||||
bet_softmax_temperature: float = 0.1
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
|
|
|
@ -270,7 +270,7 @@ class VQBeTHead(nn.Module):
|
|||
cbet_offsets = einops.rearrange(
|
||||
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.config.vqvae_groups, C=self.config.vqvae_n_embed
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits, dim=-1)
|
||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
sampled_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||
|
|
|
@ -99,4 +99,5 @@ policy:
|
|||
dropout: 0.1
|
||||
mlp_hidden_dim: 1024
|
||||
offset_loss_weight: 10000.
|
||||
secondary_code_loss_weight: 0.5
|
||||
secondary_code_loss_weight: 0.5
|
||||
bet_softmax_temperature: 0.1
|
Loading…
Reference in New Issue