move if self.vqvae_model.check_discretized(): part inside def pretrain_vqvae
This commit is contained in:
parent
1778dee9ab
commit
eedc131d27
|
@ -380,12 +380,6 @@ class VQBeTHead(nn.Module):
|
|||
|
||||
def discretize(self, discretize_step, actions):
|
||||
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
|
||||
# if we updated RVQ more than `discretize_step` steps,
|
||||
if self.vqvae_model.check_discretized():
|
||||
print("Finished discretizing action data!")
|
||||
self.vqvae_model.eval()
|
||||
for param in self.vqvae_model.vq_layer.parameters():
|
||||
param.requires_grad = False
|
||||
return loss, n_different_codes, n_different_combinations
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
|
@ -932,8 +926,13 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
|
|||
n_different_codes = len(torch.unique(metric[2]))
|
||||
n_different_combinations = len(torch.unique(metric[2], dim=0))
|
||||
vqvae_model.optimized_steps += 1
|
||||
# if we updated RVQ more than `discretize_step` steps,
|
||||
if vqvae_model.optimized_steps >= discretize_step:
|
||||
vqvae_model.toggle_discretized(True)
|
||||
print("Finished discretizing action data!")
|
||||
vqvae_model.eval()
|
||||
for param in vqvae_model.vq_layer.parameters():
|
||||
param.requires_grad = False
|
||||
return loss, n_different_codes, n_different_combinations
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue