make the code to handle dynamically

This commit is contained in:
jayLEE0301 2024-06-05 14:01:40 -04:00
parent eedc131d27
commit fd8fc11342
1 changed files with 3 additions and 3 deletions

View File

@ -516,10 +516,10 @@ class VQBeTHead(nn.Module):
cbet_logits[:, 0, :],
action_bins[:, 0],
)
# calculate secondary code prediction loss
# calculate secondary code prediction loss (if there are more than 2 layers in RVQ, then this part will calculate all the loss for remaining layers together)
cbet_loss2 = self._focal_loss_fn(
cbet_logits[:, 1, :],
action_bins[:, 1],
cbet_logits[:, 1:, :],
action_bins[:, 1:],
)
# add all the prediction loss
cbet_loss = cbet_loss1 * self.config.primary_code_loss_weight + cbet_loss2 * self.config.secondary_code_loss_weight