make the code to handle dynamically
This commit is contained in:
parent
eedc131d27
commit
fd8fc11342
|
@ -516,10 +516,10 @@ class VQBeTHead(nn.Module):
|
|||
cbet_logits[:, 0, :],
|
||||
action_bins[:, 0],
|
||||
)
|
||||
# calculate secondary code prediction loss
|
||||
# calculate secondary code prediction loss (if there are more than 2 layers in RVQ, then this part will calculate all the loss for remaining layers together)
|
||||
cbet_loss2 = self._focal_loss_fn(
|
||||
cbet_logits[:, 1, :],
|
||||
action_bins[:, 1],
|
||||
cbet_logits[:, 1:, :],
|
||||
action_bins[:, 1:],
|
||||
)
|
||||
# add all the prediction loss
|
||||
cbet_loss = cbet_loss1 * self.config.primary_code_loss_weight + cbet_loss2 * self.config.secondary_code_loss_weight
|
||||
|
|
Loading…
Reference in New Issue