Remove draw_logits_forward, since it is unused, and rename draw_code_forward to get_embeddings_from_code

This commit is contained in:
jayLEE0301 2024-06-08 17:29:01 -04:00
parent 2f0e601d9f
commit cff71e6a54
1 changed file with 3 additions and 16 deletions


@@ -449,8 +449,8 @@ class VQBeTHead(nn.Module):
         # Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
         sampled_offsets = sampled_offsets.sum(dim=1)
         with torch.no_grad():
-            # Get the centroids of each layer to pass it through RVQ decoder
-            return_decoder_input = self.vqvae_model.draw_code_forward(sampled_centers).clone().detach()
+            # Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
+            return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
         # pass the centroids through decoder to get actions.
         decoded_action = (
             self.vqvae_model.get_action_from_latent(return_decoder_input)
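
For context, a minimal sketch of how the renamed method is exercised at this call site. The trained-model handle vqvae_model, the codebook size (16), and the layer/batch counts are illustrative assumptions, not part of the diff:

    import torch

    # hypothetical: 8 samples, 2 RVQ layers, codebook size 16
    sampled_centers = torch.randint(0, 16, (8, 2))
    with torch.no_grad():
        # look up the codebook vectors ("centroids") for the sampled indices
        latent = vqvae_model.get_embeddings_from_code(sampled_centers)
        # decode the recombined latent into an action chunk
        actions = vqvae_model.get_action_from_latent(latent)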
@@ -776,11 +776,7 @@ class VqVae(nn.Module):
             hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.action_chunk_size],
         )
 
-    def draw_logits_forward(self, encoding_logits):
-        z_embed = self.vq_layer.draw_logits_forward(encoding_logits)
-        return z_embed
-
-    def draw_code_forward(self, encoding_indices):
+    def get_embeddings_from_code(self, encoding_indices):
         with torch.no_grad():
             z_embed = self.vq_layer.get_codes_from_indices(encoding_indices)
             z_embed = z_embed.sum(dim=0)
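
A note on the kept body: get_codes_from_indices returns the per-layer codebook vectors stacked on a leading layer axis, and in residual VQ the quantized latent is the sum of one code per layer, which is why the method reduces with sum(dim=0). A shape-only sketch of that reduction (all dimensions are illustrative):

    import torch

    num_layers, batch, embed_dim = 2, 8, 256
    z_embed = torch.randn(num_layers, batch, embed_dim)  # one codebook vector per layer
    latent = z_embed.sum(dim=0)  # (batch, embed_dim): residual layers recombined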
@@ -994,15 +990,6 @@ class ResidualVQ(nn.Module):
 
         return all_codes
 
-    def draw_logits_forward(self, encoding_logits):
-        # encoding_indices : dim1 = batch_size dim2 = 4 (number of groups) dim3 = vq dict size (header)
-        encoding_logits = encoding_logits
-        bs = encoding_logits.shape[0]
-        quantized = torch.zeros((bs, self.codebooks.shape[-1]))
-        for q in range(encoding_logits.shape[1]):
-            quantized += torch.matmul(encoding_logits[:, q], self.codebooks[q])
-        return quantized
-
     def forward(
         self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
     ):
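
For the record, the deleted draw_logits_forward computed a "soft" codebook lookup: each layer's codebook weighted by logits, accumulated over layers. Should it ever be needed again, the loop collapses to one einsum; a hedged equivalent, assuming logits of shape (batch, num_layers, codebook_size) and codebooks of shape (num_layers, codebook_size, embed_dim):

    import torch

    batch, num_layers, codebook_size, embed_dim = 8, 2, 16, 256
    encoding_logits = torch.randn(batch, num_layers, codebook_size)
    codebooks = torch.randn(num_layers, codebook_size, embed_dim)
    # weight each layer's codebook by its logits, then sum over layers
    quantized = torch.einsum("bqc,qcd->bd", encoding_logits, codebooks)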