fix header -> head, bind pred -> bin prediction

This commit is contained in:
jayLEE0301 2024-06-04 17:10:39 -04:00
parent 87842c0d19
commit 833d440ebf
1 changed files with 2 additions and 2 deletions

View File

@ -115,7 +115,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
# n_different_combinations: how many different code combinations you are using out of all possible code combinations (max: vqvae_n_embed ^ vqvae_groups).
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin ped header / offset header parts.
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
_, loss_dict = self.vqbet(batch, rollout=False)
return loss_dict
@ -269,7 +269,7 @@ class VQBeTModel(nn.Module):
# GPT part of VQ-BeT
self.policy = GPT(config)
# bin prediction header / offset prediction header part of VQ-BeT
# bin prediction head / offset prediction head part of VQ-BeT
self.action_head = VQBeTHead(config)
def discretize(self, discretize_step, actions):