remove toggle_discretized, load_state_dict, and change freeze_codebook in class ResidualVQ as resigered buffer
This commit is contained in:
parent
670c08a2e4
commit
cde1804bc7
|
@ -851,13 +851,6 @@ class VqVae(nn.Module):
|
|||
return rep_loss, metric
|
||||
|
||||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
super(VqVae, self).state_dict(self, *args, **kwargs)
|
||||
self.toggle_discretized(True)
|
||||
|
||||
def toggle_discretized(self, state=True):
|
||||
self.discretized = torch.tensor(state)
|
||||
self.vq_layer.freeze_codebook = state
|
||||
|
||||
|
||||
def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
|
||||
|
@ -879,9 +872,10 @@ def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
|
|||
n_different_codes = len(torch.unique(metric[2]))
|
||||
n_different_combinations = len(torch.unique(metric[2], dim=0))
|
||||
vqvae_model.optimized_steps += 1
|
||||
# if we updated RVQ more than `n_vqvae_training_steps` steps,
|
||||
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
|
||||
if vqvae_model.optimized_steps >= n_vqvae_training_steps:
|
||||
vqvae_model.toggle_discretized(True)
|
||||
vqvae_model.discretized = torch.tensor(True)
|
||||
vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
|
||||
print("Finished discretizing action data!")
|
||||
vqvae_model.eval()
|
||||
for param in vqvae_model.vq_layer.parameters():
|
||||
|
@ -962,7 +956,7 @@ class ResidualVQ(nn.Module):
|
|||
|
||||
assert quantize_dropout_cutoff_index >= 0
|
||||
|
||||
self.freeze_codebook = False
|
||||
self.register_buffer('freeze_codebook', torch.tensor(False))
|
||||
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
|
||||
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
|
||||
|
||||
|
|
Loading…
Reference in New Issue