add bin pred temperature

This commit is contained in:
jayLEE0301 2024-05-24 15:44:18 -04:00
parent 8ee1e53fee
commit b3e0ec1afd
3 changed files with 4 additions and 2 deletions

View File

@ -108,6 +108,7 @@ class VQBeTConfig:
mlp_hidden_dim: int = 1024
offset_loss_weight: float = 10000.
secondary_code_loss_weight: float = 0.5
bet_softmax_temperature: float = 0.1
def __post_init__(self):
"""Input validation (not exhaustive)."""

View File

@ -270,7 +270,7 @@ class VQBeTHead(nn.Module):
cbet_offsets = einops.rearrange(
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.config.vqvae_groups, C=self.config.vqvae_n_embed
)
cbet_probs = torch.softmax(cbet_logits, dim=-1)
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape
sampled_centers = einops.rearrange(
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),

View File

@ -99,4 +99,5 @@ policy:
dropout: 0.1
mlp_hidden_dim: 1024
offset_loss_weight: 10000.
secondary_code_loss_weight: 0.5
secondary_code_loss_weight: 0.5
bet_softmax_temperature: 0.1