move if self.vqvae_model.check_discretized(): part inside def pretrain_vqvae

This commit is contained in:
jayLEE0301 2024-06-05 13:56:21 -04:00
parent 1778dee9ab
commit eedc131d27
1 changed files with 5 additions and 6 deletions

View File

@ -380,12 +380,6 @@ class VQBeTHead(nn.Module):
def discretize(self, discretize_step, actions):
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
# if we updated RVQ more than `discretize_step` steps,
if self.vqvae_model.check_discretized():
print("Finished discretizing action data!")
self.vqvae_model.eval()
for param in self.vqvae_model.vq_layer.parameters():
param.requires_grad = False
return loss, n_different_codes, n_different_combinations
def forward(self, x, **kwargs):
@ -932,8 +926,13 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
n_different_codes = len(torch.unique(metric[2]))
n_different_combinations = len(torch.unique(metric[2], dim=0))
vqvae_model.optimized_steps += 1
# if we updated RVQ more than `discretize_step` steps,
if vqvae_model.optimized_steps >= discretize_step:
vqvae_model.toggle_discretized(True)
print("Finished discretizing action data!")
vqvae_model.eval()
for param in vqvae_model.vq_layer.parameters():
param.requires_grad = False
return loss, n_different_codes, n_different_combinations