add explanations for n_different_codes, n_different_combinations

This commit is contained in:
jayLEE0301 2024-06-04 17:05:04 -04:00
parent 6012b4c859
commit 87842c0d19
1 changed files with 3 additions and 0 deletions

View File

@ -110,6 +110,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.vqbet.action_head.vqvae_model.check_discretized():
# loss: total loss of training RVQ
# n_different_codes: how many of total possible codes are being used (max: vqvae_n_embed).
# n_different_combinations: how many different code combinations you are using out of all possible code combinations (max: vqvae_n_embed ^ vqvae_groups).
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction header / offset header parts.