fix header -> head, bind pred -> bin prediction
This commit is contained in:
parent
87842c0d19
commit
833d440ebf
|
@ -115,7 +115,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
# n_different_combinations: how many different code combinations you are using out of all possible code combinations (max: vqvae_n_embed ^ vqvae_groups).
|
||||
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
|
||||
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin ped header / offset header parts.
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
|
||||
_, loss_dict = self.vqbet(batch, rollout=False)
|
||||
|
||||
return loss_dict
|
||||
|
@ -269,7 +269,7 @@ class VQBeTModel(nn.Module):
|
|||
|
||||
# GPT part of VQ-BeT
|
||||
self.policy = GPT(config)
|
||||
# bin prediction header / offset prediction header part of VQ-BeT
|
||||
# bin prediction head / offset prediction head part of VQ-BeT
|
||||
self.action_head = VQBeTHead(config)
|
||||
|
||||
def discretize(self, discretize_step, actions):
|
||||
|
|
Loading…
Reference in New Issue