add comments in class VqVae, change get_codes_from_indices -> get_codebook_vector_from_indices
This commit is contained in:
parent
cff71e6a54
commit
32fa5d22b9
|
@ -753,11 +753,14 @@ class VqVae(nn.Module):
|
|||
VQ-VAE is composed of three parts: encoder, vq_layer, and decoder.
|
||||
The encoder and decoder are each an MLP consisting of an input layer, hidden layer(s), and an output layer.
|
||||
The vq_layer uses residual VQs.
|
||||
|
||||
This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
|
||||
as well as functions to help BeT training part in training phase 2.
|
||||
"""
|
||||
|
||||
super(VqVae, self).__init__()
|
||||
self.config = config
|
||||
|
||||
# 'discretized' indicates whether the Residual VQ part is trained or not. (After finishing the training, we set discretized=True)
|
||||
self.register_buffer('discretized', torch.tensor(False))
|
||||
self.optimized_steps = 0
|
||||
|
||||
|
@ -777,12 +780,15 @@ class VqVae(nn.Module):
|
|||
)
|
||||
|
||||
def get_embeddings_from_code(self, encoding_indices):
|
||||
# This function gets code indices as inputs, and outputs embedding vectors corresponding to the code indices.
|
||||
with torch.no_grad():
|
||||
z_embed = self.vq_layer.get_codes_from_indices(encoding_indices)
|
||||
z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices)
|
||||
# since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
|
||||
z_embed = z_embed.sum(dim=0)
|
||||
return z_embed
|
||||
|
||||
def get_action_from_latent(self, latent):
|
||||
# given latent vector, this function outputs the decoded action.
|
||||
output = self.decoder(latent)
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
|
@ -790,6 +796,8 @@ class VqVae(nn.Module):
|
|||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `GT code` to calculate the Focal loss for code prediction head.
|
||||
# this function outputs the `GT code` of the given action using the frozen encoder and quantization layers (please refer to Figure 2 in the paper https://arxiv.org/pdf/2403.03181).
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
with torch.no_grad():
|
||||
state_rep = self.encoder(state)
|
||||
|
@ -802,18 +810,23 @@ class VqVae(nn.Module):
|
|||
return state_vq, vq_code
|
||||
|
||||
def vqvae_forward(self, state):
|
||||
# This function passes the given data through the Residual VQ with the Encoder and Decoder. Please refer to section 3.2 in the paper (https://arxiv.org/pdf/2403.03181).
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
# We start by passing the action (or action chunk) a_{t:t+n} through the encoder ϕ.
|
||||
state_rep = self.encoder(state)
|
||||
state_rep_shape = state_rep.shape[:-1]
|
||||
state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
|
||||
# The resulting latent embedding vector x = ϕ(a_{t:t+n}) is then mapped to an embedding vector in the codebook of the RVQ layers by nearest-neighbor look-up.
|
||||
state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
|
||||
state_vq = state_rep_flat.view(*state_rep_shape, -1)
|
||||
vq_code = vq_code.view(*state_rep_shape, -1)
|
||||
# the VQ loss returned by the RVQ layer is summed into a single scalar loss value.
|
||||
vq_loss_state = torch.sum(vq_loss_state)
|
||||
|
||||
# Then, the discretized vector z_q(x) is reconstructed as ψ(z_q(x)) by passing it through the decoder ψ.
|
||||
dec_out = self.decoder(state_vq)
|
||||
# Calculate L1 reconstruction loss
|
||||
encoder_loss = (state - dec_out).abs().mean()
|
||||
|
||||
# add encoder reconstruction loss and commitment loss
|
||||
rep_loss = encoder_loss + vq_loss_state * 5
|
||||
|
||||
metric = (
|
||||
|
@ -950,7 +963,7 @@ class ResidualVQ(nn.Module):
|
|||
codebooks = rearrange(codebooks, "q 1 c d -> q c d")
|
||||
return codebooks
|
||||
|
||||
def get_codes_from_indices(self, indices):
|
||||
def get_codebook_vector_from_indices(self, indices):
|
||||
batch, quantize_dim = indices.shape[0], indices.shape[-1]
|
||||
|
||||
# may also receive indices in the shape of 'b h w q' (accept_image_fmap)
|
||||
|
@ -1098,7 +1111,7 @@ class ResidualVQ(nn.Module):
|
|||
|
||||
if return_all_codes:
|
||||
# whether to return all codes from all codebooks across layers
|
||||
all_codes = self.get_codes_from_indices(all_indices)
|
||||
all_codes = self.get_codebook_vector_from_indices(all_indices)
|
||||
|
||||
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
|
||||
ret = (*ret, all_codes)
|
||||
|
@ -1252,7 +1265,7 @@ class VectorQuantize(nn.Module):
|
|||
|
||||
self._codebook.embed.copy_(codes)
|
||||
|
||||
def get_codes_from_indices(self, indices):
|
||||
def get_codebook_vector_from_indices(self, indices):
|
||||
codebook = self.codebook
|
||||
is_multiheaded = codebook.ndim > 2
|
||||
|
||||
|
|
Loading…
Reference in New Issue