removed def train and def check_discretized, and restored def toggle_discretized

jayLEE0301 2024-06-08 16:33:43 -04:00
parent 1257aaf438
commit 0b48268689
1 changed file with 9 additions and 29 deletions


@@ -86,7 +86,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_inputs(batch)
         self._queues = populate_queues(self._queues, batch)
-        assert self.vqbet.action_head.vqvae_model.check_discretized(), "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ."
+        assert self.vqbet.action_head.vqvae_model.discretized.item(), "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ."
         assert "observation.image" in batch
         assert "observation.state" in batch
@@ -108,7 +108,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
         # VQ-BeT discretizes actions using a VQ-VAE before training BeT (see section 3.2 of the VQ-BeT paper, https://arxiv.org/pdf/2403.03181)
-        if not self.vqbet.action_head.vqvae_model.check_discretized():
+        if not self.vqbet.action_head.vqvae_model.discretized.item():
             # loss: total loss of training RVQ
             # n_different_codes: how many of the total possible codes are being used (max: vqvae_n_embed).
             # n_different_combinations: how many different code combinations are being used out of all possible combinations (max: vqvae_n_embed ^ vqvae_groups).
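The comments above describe the RVQ usage metrics reported during discretization. As a hedged sketch (not code from this commit), both counts can be derived from the code-index tensor a residual VQ returns, assuming a shape of (batch, vqvae_groups):

import torch

# fake RVQ indices: 128 samples, vqvae_groups=2 layers, vqvae_n_embed=16 codes per layer
indices = torch.randint(0, 16, (128, 2))

n_different_codes = indices.unique().numel()               # max: vqvae_n_embed
n_different_combinations = indices.unique(dim=0).shape[0]  # max: vqvae_n_embed ** vqvae_groups
print(n_different_codes, n_different_combinations)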
@@ -776,28 +776,6 @@ class VqVae(nn.Module):
             hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.action_chunk_size],
         )
-    def check_discretized(self):
-        return self.discretized.item()
-    def train(self, mode=True):
-        """
-        This function forces the RVQ to no longer update when action discretization is complete.
-        Since VQs are partly updated via the EMA method, simply passing data through them can cause unintended modifications.
-        Therefore, we use function overriding to prevent RVQs from being updated during the training of VQ-BeT after discretization completes.
-        """
-        if mode:
-            if self.check_discretized():
-                pass
-            else:
-                self.training = True
-                self.vq_layer.train()
-                self.decoder.train()
-                self.encoder.train()
-        else:
-            self.eval()
     def draw_logits_forward(self, encoding_logits):
         z_embed = self.vq_layer.draw_logits_forward(encoding_logits)
         return z_embed
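The deleted docstring carries the key rationale: EMA-updated codebooks move whenever data flows through them in train mode, so the old code overrode train() to pin the RVQ after discretization. A toy sketch of that failure mode and of the freeze_codebook guard this commit threads through instead (all names illustrative, not from the repo):

import torch

class ToyEmaCodebook(torch.nn.Module):
    def __init__(self, n_codes=8, dim=4, decay=0.8):
        super().__init__()
        self.register_buffer("embed", torch.randn(n_codes, dim))
        self.decay = decay
        self.freeze_codebook = False  # counterpart of the flag added below

    def forward(self, x):
        idx = torch.cdist(x, self.embed).argmin(dim=-1)
        if self.training and not self.freeze_codebook:
            # EMA update: even a loss-free forward pass moves the codes
            for i in idx.unique():
                mean = x[idx == i].mean(dim=0)
                self.embed[i] = self.decay * self.embed[i] + (1 - self.decay) * mean
        return self.embed[idx]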
@@ -875,10 +853,11 @@ class VqVae(nn.Module):
     def load_state_dict(self, *args, **kwargs):
         super().load_state_dict(*args, **kwargs)
-        self.discretized = torch.tensor(True)
+        self.toggle_discretized(True)
+    def toggle_discretized(self, state=True):
+        self.discretized = torch.tensor(state)
+        self.vq_layer.freeze_codebook = state
 def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
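toggle_discretized couples the two pieces of state that separate assignments used to update independently: the discretized flag and the codebook freeze. A hedged, self-contained sketch of that contract (MiniVqVae is a stand-in, with SimpleNamespace in place of the real ResidualVQ):

import types
import torch

class MiniVqVae:
    def __init__(self):
        self.vq_layer = types.SimpleNamespace(freeze_codebook=False)
        self.discretized = torch.tensor(False)

    def toggle_discretized(self, state=True):
        # one call keeps the flag and the codebook freeze in sync
        self.discretized = torch.tensor(state)
        self.vq_layer.freeze_codebook = state

m = MiniVqVae()
m.toggle_discretized(True)
assert m.discretized.item() and m.vq_layer.freeze_codebook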
@@ -902,7 +881,7 @@ def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
     vqvae_model.optimized_steps += 1
     # if the RVQ has been updated for more than `n_vqvae_training_steps` steps,
     if vqvae_model.optimized_steps >= n_vqvae_training_steps:
-        vqvae_model.discretized = torch.tensor(True)
+        vqvae_model.toggle_discretized(True)
         print("Finished discretizing action data!")
         vqvae_model.eval()
         for param in vqvae_model.vq_layer.parameters():
@@ -983,6 +962,7 @@ class ResidualVQ(nn.Module):
         assert quantize_dropout_cutoff_index >= 0
+        self.freeze_codebook = False
         self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
         self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # the EnCodec paper proposes structured dropout; there it was reportedly set to 4
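The trailing comment refers to EnCodec-style structured quantizer dropout: during training a random cutoff layer is sampled, deeper quantizers are dropped, and quantize_dropout_multiple_of rounds the count of active layers. A hedged paraphrase of that scheme (logic assumed from the usual vector-quantize-pytorch behavior, not quoted from this file):

import math
import random

def sample_quantizer_cutoff(num_quantizers, cutoff_index, multiple_of=4):
    # pick a random layer index past which quantizers are dropped this step
    drop_index = random.randrange(cutoff_index, num_quantizers)
    if multiple_of > 1:
        # structured variant: round up so the active-layer count is a multiple
        drop_index = math.ceil((drop_index + 1) / multiple_of) * multiple_of - 1
    return drop_index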
@@ -1124,6 +1104,7 @@ class ResidualVQ(nn.Module):
                 residual,
                 indices=layer_indices,
                 sample_codebook_temp=sample_codebook_temp,
+                freeze_codebook=self.freeze_codebook,
             )
             residual = residual - quantized.detach()
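For context on the call site above: each RVQ layer quantizes the residual left over by the layers before it, which is why freeze_codebook has to be forwarded into every per-layer call. A minimal, self-contained sketch of that recursion with toy nearest-neighbor codebooks (not the repo's layer class):

import torch

def residual_vq_forward(x, codebooks):
    residual = x
    quantized_out = torch.zeros_like(x)
    for embed in codebooks:  # one (n_codes, dim) codebook per RVQ layer
        idx = torch.cdist(residual, embed).argmin(dim=-1)
        quantized = embed[idx]
        residual = residual - quantized.detach()  # the next layer sees what is left
        quantized_out = quantized_out + quantized
    return quantized_out

out = residual_vq_forward(torch.randn(8, 4), [torch.randn(16, 4) for _ in range(2)])
print(out.shape)  # torch.Size([8, 4])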
@@ -1178,7 +1159,6 @@ class VectorQuantize(nn.Module):
         separate_codebook_per_head=False,
         decay=0.8,
         eps=1e-5,
-        freeze_codebook=False,
         kmeans_init=False,
         kmeans_iters=10,
         sync_kmeans=True,
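The last hunk removes freeze_codebook from the VectorQuantize constructor; together with the forwarding added in the ResidualVQ hunk, freezing becomes a per-call argument owned by the wrapper rather than fixed state on each layer. A hedged toy sketch of that API shift (signatures illustrative, not copied from the file):

class ToyVectorQuantize:
    # per-call freeze: the layer itself no longer stores freeze_codebook
    def forward(self, x, freeze_codebook=False):
        if not freeze_codebook:
            pass  # the EMA codebook update would run here
        return x

vq = ToyVectorQuantize()
vq.forward([0.1, 0.2])                        # during RVQ pretraining: codes may move
vq.forward([0.1, 0.2], freeze_codebook=True)  # during BeT training: codes pinned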