removed def train, def check_discretized, and restore def toggle_discretized
This commit is contained in:
parent
1257aaf438
commit
0b48268689
|
@ -86,7 +86,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
batch = self.normalize_inputs(batch)
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
assert self.vqbet.action_head.vqvae_model.check_discretized(), "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ."
|
||||
assert self.vqbet.action_head.vqvae_model.discretized.item(), "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ."
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
|
@ -108,7 +108,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.check_discretized():
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
# loss: total loss of training RVQ
|
||||
# n_different_codes: how many of total possible codes are being used (max: vqvae_n_embed).
|
||||
# n_different_combinations: how many different code combinations you are using out of all possible code combinations (max: vqvae_n_embed ^ vqvae_groups).
|
||||
|
@ -776,28 +776,6 @@ class VqVae(nn.Module):
|
|||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.action_chunk_size],
|
||||
)
|
||||
|
||||
|
||||
def check_discretized(self):
|
||||
return self.discretized.item()
|
||||
|
||||
|
||||
def train(self, mode=True):
|
||||
"""
|
||||
This function forces the RVQ to no longer update when action discretization is complete.
|
||||
Since VQs are partly updated via the EMA method, simply passing data through them can cause unintended modifications.
|
||||
Therefore, we use function overriding to prevent RVQs from being updated during the training of VQ-BeT after discretization completes.
|
||||
"""
|
||||
if mode:
|
||||
if self.check_discretized():
|
||||
pass
|
||||
else:
|
||||
self.training = True
|
||||
self.vq_layer.train()
|
||||
self.decoder.train()
|
||||
self.encoder.train()
|
||||
else:
|
||||
self.eval()
|
||||
|
||||
def draw_logits_forward(self, encoding_logits):
|
||||
z_embed = self.vq_layer.draw_logits_forward(encoding_logits)
|
||||
return z_embed
|
||||
|
@ -875,10 +853,11 @@ class VqVae(nn.Module):
|
|||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
super(VqVae, self).state_dict(self, *args, **kwargs)
|
||||
self.discretized = torch.tensor(True)
|
||||
|
||||
|
||||
self.toggle_discretized(True)
|
||||
|
||||
def toggle_discretized(self, state=True):
|
||||
self.discretized = torch.tensor(state)
|
||||
self.vq_layer.freeze_codebook = state
|
||||
|
||||
|
||||
def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
|
||||
|
@ -902,7 +881,7 @@ def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
|
|||
vqvae_model.optimized_steps += 1
|
||||
# if we updated RVQ more than `n_vqvae_training_steps` steps,
|
||||
if vqvae_model.optimized_steps >= n_vqvae_training_steps:
|
||||
vqvae_model.discretized = torch.tensor(True)
|
||||
vqvae_model.toggle_discretized(True)
|
||||
print("Finished discretizing action data!")
|
||||
vqvae_model.eval()
|
||||
for param in vqvae_model.vq_layer.parameters():
|
||||
|
@ -983,6 +962,7 @@ class ResidualVQ(nn.Module):
|
|||
|
||||
assert quantize_dropout_cutoff_index >= 0
|
||||
|
||||
self.freeze_codebook = False
|
||||
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
|
||||
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
|
||||
|
||||
|
@ -1124,6 +1104,7 @@ class ResidualVQ(nn.Module):
|
|||
residual,
|
||||
indices=layer_indices,
|
||||
sample_codebook_temp=sample_codebook_temp,
|
||||
freeze_codebook = self.freeze_codebook
|
||||
)
|
||||
|
||||
residual = residual - quantized.detach()
|
||||
|
@ -1178,7 +1159,6 @@ class VectorQuantize(nn.Module):
|
|||
separate_codebook_per_head=False,
|
||||
decay=0.8,
|
||||
eps=1e-5,
|
||||
freeze_codebook=False,
|
||||
kmeans_init=False,
|
||||
kmeans_iters=10,
|
||||
sync_kmeans=True,
|
||||
|
|
Loading…
Reference in New Issue