add comments in class VqVae, change get_codes_from_indices -> get_codebook_vector_from_indices

jayLEE0301 2024-06-08 18:16:51 -04:00
parent cff71e6a54
commit 32fa5d22b9
1 changed file with 20 additions and 7 deletions


@@ -753,11 +753,14 @@ class VqVae(nn.Module):
VQ-VAE is composed of three parts: encoder, vq_layer, and decoder.
Encoder and decoder are MLPs, each consisting of an input layer, hidden layers, and an output layer.
The vq_layer uses residual VQs.
This class contains functions for training the encoder and decoder along with the residual VQ layer (training phase 1),
as well as functions that support the BeT part of training phase 2.
"""
super(VqVae, self).__init__()
self.config = config
# `discretized` indicates whether the Residual VQ has been trained. (After phase-1 training finishes, we set discretized=True.)
self.register_buffer('discretized', torch.tensor(False))
self.optimized_steps = 0
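A side note on the buffer above: registering `discretized` as a buffer (rather than a plain attribute) makes the phase-1/phase-2 state travel with checkpoints. A minimal sketch of that pattern, with a hypothetical class name and flow that are not the repository's actual entry point:

import torch
import torch.nn as nn

class PhaseFlag(nn.Module):
    def __init__(self):
        super().__init__()
        # A buffer is saved in state_dict() but is not a learnable parameter,
        # so a checkpoint taken after phase 1 remembers that the RVQ is frozen.
        self.register_buffer("discretized", torch.tensor(False))

flag = PhaseFlag()
assert not flag.discretized.item()           # still in phase 1
flag.discretized = torch.tensor(True)        # flip once RVQ training finishes
assert "discretized" in flag.state_dict()    # the flag is checkpointed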
@@ -777,12 +780,15 @@ class VqVae(nn.Module):
)
def get_embeddings_from_code(self, encoding_indices):
# Given code indices as input, this function outputs the embedding vectors corresponding to those codes.
with torch.no_grad():
- z_embed = self.vq_layer.get_codes_from_indices(encoding_indices)
+ z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices)
# Since the RVQ has multiple layers, the vectors are summed over the layer axis to produce a single embedding for that code combination.
z_embed = z_embed.sum(dim=0)
return z_embed
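To make the layer-axis sum above concrete, here is a pure-torch sketch of how a residual-VQ code combination maps back to a single embedding (codebook sizes and shapes are illustrative, not the repository's configuration):

import torch

num_layers, codebook_size, dim = 2, 16, 8
codebooks = torch.randn(num_layers, codebook_size, dim)     # one codebook per RVQ layer
indices = torch.randint(0, codebook_size, (3, num_layers))  # (batch, num_layers)

# Look up each layer's codebook vector: (num_layers, batch, dim).
per_layer = torch.stack([codebooks[q][indices[:, q]] for q in range(num_layers)])
# Summing over the layer axis yields one embedding per code combination,
# matching z_embed.sum(dim=0) above.
z_embed = per_layer.sum(dim=0)
print(z_embed.shape)   # torch.Size([3, 8])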
def get_action_from_latent(self, latent):
# Given a latent vector, this function outputs the decoded action.
output = self.decoder(latent)
if self.config.action_chunk_size == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
@@ -790,6 +796,8 @@ class VqVae(nn.Module):
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
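The rearrange used in both branches simply unflattens the decoder output into an action chunk; a quick einops illustration with made-up sizes:

import torch
from einops import rearrange

batch, chunk, action_dim = 4, 5, 7
flat = torch.randn(batch, chunk * action_dim)   # decoder output, "N (T A)"
actions = rearrange(flat, "N (T A) -> N T A", A=action_dim)
print(actions.shape)   # torch.Size([4, 5, 7])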
def get_code(self, state):
# In phase 2 of VQ-BeT training, we need a `GT code` to calculate the focal loss for the code prediction head.
# This function outputs the `GT code` of a given action using the frozen encoder and quantization layers. (Please refer to Figure 2 in the paper https://arxiv.org/pdf/2403.03181.)
state = einops.rearrange(state, "N T A -> N (T A)")
with torch.no_grad():
state_rep = self.encoder(state)
@@ -802,18 +810,23 @@ class VqVae(nn.Module):
return state_vq, vq_code
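For intuition, the frozen nearest-neighbor lookup that produces a code can be sketched in plain torch (a single quantizer layer with illustrative sizes; the repository stacks several such layers in the RVQ):

import torch

codebook = torch.randn(16, 8)          # (codebook_size, dim)
x = torch.randn(3, 8)                  # latent embeddings from the frozen encoder
with torch.no_grad():                  # encoder and codebooks stay frozen in phase 2
    dists = torch.cdist(x, codebook)   # (3, 16) pairwise Euclidean distances
    vq_code = dists.argmin(dim=-1)     # the GT code targets for the focal loss
    state_vq = codebook[vq_code]       # the quantized vectors z_q(x)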
def vqvae_forward(self, state):
# This function passes the given data through the encoder, the Residual VQ, and the decoder. (Please refer to Section 3.2 in the paper https://arxiv.org/pdf/2403.03181.)
state = einops.rearrange(state, "N T A -> N (T A)")
# We start by passing the action (or action chunk) a_{t:t+n} through the encoder ϕ.
state_rep = self.encoder(state)
state_rep_shape = state_rep.shape[:-1]
state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
# The resulting latent embedding vector x = ϕ(a_{t:t+n}) is then mapped to an embedding vector in the codebooks of the RVQ layers by a nearest-neighbor look-up.
state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
state_vq = state_rep_flat.view(*state_rep_shape, -1)
vq_code = vq_code.view(*state_rep_shape, -1)
# Since the RVQ has multiple layers, the commitment losses from all layers are summed over the layer axis.
vq_loss_state = torch.sum(vq_loss_state)
# Then, the quantized vector z_q(x) is reconstructed as ψ(z_q(x)) by passing it through the decoder ψ.
dec_out = self.decoder(state_vq)
# Calculate L1 reconstruction loss
encoder_loss = (state - dec_out).abs().mean()
# Add the encoder reconstruction loss and the commitment loss.
rep_loss = encoder_loss + vq_loss_state * 5
metric = (
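To tie the commented steps of vqvae_forward together, here is a compact, self-contained sketch of the phase-1 objective with a two-layer residual quantizer. This is plain torch with illustrative shapes and codebook sizes; gradient flow through the codebook lookup (straight-through estimation) is omitted, and only the *5 loss weight is taken from the line above:

import torch

def rvq_quantize(x, codebooks):
    # Residual VQ: each layer quantizes what the previous layers left over.
    residual = x
    quantized = torch.zeros_like(x)
    commit_loss = x.new_zeros(())
    for cb in codebooks:
        idx = torch.cdist(residual, cb).argmin(dim=-1)  # nearest code per item
        q = cb[idx]
        commit_loss = commit_loss + ((residual - q) ** 2).mean()
        quantized = quantized + q                       # sum over the layer axis
        residual = residual - q                         # next layer sees the leftover
    return quantized, commit_loss

encoder = torch.nn.Linear(35, 8)     # phi: flattened action chunk -> latent
decoder = torch.nn.Linear(8, 35)     # psi: latent -> reconstructed actions
codebooks = [torch.randn(16, 8) for _ in range(2)]

state = torch.randn(4, 35)                        # "N (T A)" flattened actions
state_vq, vq_loss_state = rvq_quantize(encoder(state), codebooks)
dec_out = decoder(state_vq)
encoder_loss = (state - dec_out).abs().mean()     # L1 reconstruction loss
rep_loss = encoder_loss + vq_loss_state * 5       # reconstruction + commitment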
@@ -950,7 +963,7 @@ class ResidualVQ(nn.Module):
codebooks = rearrange(codebooks, "q 1 c d -> q c d")
return codebooks
- def get_codes_from_indices(self, indices):
+ def get_codebook_vector_from_indices(self, indices):
batch, quantize_dim = indices.shape[0], indices.shape[-1]
# may also receive indices in the shape of 'b h w q' (accept_image_fmap)
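As a quick shape reference for the layouts mentioned in the two comments above, an einops sketch with made-up sizes (the flattening of image-format indices is shown only to illustrate the shape convention, not the library's exact code path):

import torch
from einops import rearrange

# Per-layer codebooks carry a singleton head axis: "q 1 c d" -> "q c d".
codebooks = torch.randn(2, 1, 16, 8)
print(rearrange(codebooks, "q 1 c d -> q c d").shape)    # torch.Size([2, 16, 8])

# Image-format indices "b h w q" flatten the spatial axes into a sequence.
indices = torch.randint(0, 16, (3, 4, 4, 2))
print(rearrange(indices, "b h w q -> b (h w) q").shape)  # torch.Size([3, 16, 2])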
@@ -1098,7 +1111,7 @@ class ResidualVQ(nn.Module):
if return_all_codes:
# whether to return all codes from all codebooks across layers
- all_codes = self.get_codes_from_indices(all_indices)
+ all_codes = self.get_codebook_vector_from_indices(all_indices)
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
ret = (*ret, all_codes)
@@ -1252,7 +1265,7 @@ class VectorQuantize(nn.Module):
self._codebook.embed.copy_(codes)
- def get_codes_from_indices(self, indices):
+ def get_codebook_vector_from_indices(self, indices):
codebook = self.codebook
is_multiheaded = codebook.ndim > 2