add explanaitons for n_different_codes, n_different_combinations
This commit is contained in:
parent
6012b4c859
commit
87842c0d19
|
@ -110,6 +110,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.check_discretized():
|
||||
# loss: total loss of training RVQ
|
||||
# n_different_codes: how many of total possible codes are being used (max: vqvae_n_embed).
|
||||
# n_different_combinations: how many different code combinations you are using out of all possible code combinations (max: vqvae_n_embed ^ vqvae_groups).
|
||||
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
|
||||
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin ped header / offset header parts.
|
||||
|
|
Loading…
Reference in New Issue