From b3e0ec1afdb5e103dd96883eb92228e3a780c3ce Mon Sep 17 00:00:00 2001 From: jayLEE0301 Date: Fri, 24 May 2024 15:44:18 -0400 Subject: [PATCH] add bin pred temperature --- lerobot/common/policies/vqbet/configuration_vqbet.py | 1 + lerobot/common/policies/vqbet/modeling_vqbet.py | 2 +- lerobot/configs/policy/vqbet.yaml | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py index 5b2e4b36..fbe8773d 100644 --- a/lerobot/common/policies/vqbet/configuration_vqbet.py +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -108,6 +108,7 @@ class VQBeTConfig: mlp_hidden_dim: int = 1024 offset_loss_weight: float = 10000. secondary_code_loss_weight: float = 0.5 + bet_softmax_temperature: float = 0.1 def __post_init__(self): """Input validation (not exhaustive).""" diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 437d04e6..e8ca5d8c 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -270,7 +270,7 @@ class VQBeTHead(nn.Module): cbet_offsets = einops.rearrange( cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.config.vqvae_groups, C=self.config.vqvae_n_embed ) - cbet_probs = torch.softmax(cbet_logits, dim=-1) + cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1) NT, G, choices = cbet_probs.shape sampled_centers = einops.rearrange( torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), diff --git a/lerobot/configs/policy/vqbet.yaml b/lerobot/configs/policy/vqbet.yaml index 2e6eb038..0844bec9 100644 --- a/lerobot/configs/policy/vqbet.yaml +++ b/lerobot/configs/policy/vqbet.yaml @@ -99,4 +99,5 @@ policy: dropout: 0.1 mlp_hidden_dim: 1024 offset_loss_weight: 10000. - secondary_code_loss_weight: 0.5 \ No newline at end of file + secondary_code_loss_weight: 0.5 + bet_softmax_temperature: 0.1 \ No newline at end of file