remove unnecessary resizing parts

This commit is contained in:
jayLEE0301 2024-06-05 11:57:41 -04:00
parent 63d198cf0a
commit f69d552480
1 changed files with 1 additions and 6 deletions

View File

@ -454,12 +454,7 @@ class VQBeTHead(nn.Module):
# Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
sampled_offsets = sampled_offsets.sum(dim=1)
# Get the centroids of each layer to pass it through RVQ decoder
centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
NT, -1, self.config.vqvae_embedding_dim
)
return_decoder_input = einops.rearrange(
centers.clone().detach(), "NT 1 D -> NT D"
)
return_decoder_input = self.vqvae_model.draw_code_forward(sampled_centers).clone().detach()
# pass the centroids through decoder to get actions.
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)