remove toggle_discretized, load_state_dict, and change freeze_codebook in class ResidualVQ as resigered buffer

This commit is contained in:
jayLEE0301 2024-06-08 16:58:15 -04:00
parent 670c08a2e4
commit cde1804bc7
1 changed files with 4 additions and 10 deletions

View File

@ -851,13 +851,6 @@ class VqVae(nn.Module):
return rep_loss, metric
def load_state_dict(self, *args, **kwargs):
super(VqVae, self).state_dict(self, *args, **kwargs)
self.toggle_discretized(True)
def toggle_discretized(self, state=True):
self.discretized = torch.tensor(state)
self.vq_layer.freeze_codebook = state
def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
@ -879,9 +872,10 @@ def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
n_different_codes = len(torch.unique(metric[2]))
n_different_combinations = len(torch.unique(metric[2], dim=0))
vqvae_model.optimized_steps += 1
# if we updated RVQ more than `n_vqvae_training_steps` steps,
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
if vqvae_model.optimized_steps >= n_vqvae_training_steps:
vqvae_model.toggle_discretized(True)
vqvae_model.discretized = torch.tensor(True)
vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
print("Finished discretizing action data!")
vqvae_model.eval()
for param in vqvae_model.vq_layer.parameters():
@ -962,7 +956,7 @@ class ResidualVQ(nn.Module):
assert quantize_dropout_cutoff_index >= 0
self.freeze_codebook = False
self.register_buffer('freeze_codebook', torch.tensor(False))
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4