remove def draw_logits_forward since it is not used, and change def draw_code_forward to def get_embeddings_from_code
This commit is contained in:
parent
2f0e601d9f
commit
cff71e6a54
|
@ -449,8 +449,8 @@ class VQBeTHead(nn.Module):
|
|||
# Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
with torch.no_grad():
|
||||
# Get the centroids of each layer to pass it through RVQ decoder
|
||||
return_decoder_input = self.vqvae_model.draw_code_forward(sampled_centers).clone().detach()
|
||||
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
|
||||
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
|
||||
# pass the centroids through decoder to get actions.
|
||||
decoded_action = (
|
||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
|
@ -776,11 +776,7 @@ class VqVae(nn.Module):
|
|||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.action_chunk_size],
|
||||
)
|
||||
|
||||
def draw_logits_forward(self, encoding_logits):
|
||||
z_embed = self.vq_layer.draw_logits_forward(encoding_logits)
|
||||
return z_embed
|
||||
|
||||
def draw_code_forward(self, encoding_indices):
|
||||
def get_embeddings_from_code(self, encoding_indices):
|
||||
with torch.no_grad():
|
||||
z_embed = self.vq_layer.get_codes_from_indices(encoding_indices)
|
||||
z_embed = z_embed.sum(dim=0)
|
||||
|
@ -994,15 +990,6 @@ class ResidualVQ(nn.Module):
|
|||
|
||||
return all_codes
|
||||
|
||||
def draw_logits_forward(self, encoding_logits):
|
||||
# encoding_indices : dim1 = batch_size dim2 = 4 (number of groups) dim3 = vq dict size (header)
|
||||
encoding_logits = encoding_logits
|
||||
bs = encoding_logits.shape[0]
|
||||
quantized = torch.zeros((bs, self.codebooks.shape[-1]))
|
||||
for q in range(encoding_logits.shape[1]):
|
||||
quantized += torch.matmul(encoding_logits[:, q], self.codebooks[q])
|
||||
return quantized
|
||||
|
||||
def forward(
|
||||
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
|
||||
):
|
||||
|
|
Loading…
Reference in New Issue