remove unnecessary resizing parts
This commit is contained in:
parent
63d198cf0a
commit
f69d552480
|
@ -454,12 +454,7 @@ class VQBeTHead(nn.Module):
|
|||
# Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
# Get the centroids of each layer to pass it through RVQ decoder
|
||||
centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
|
||||
NT, -1, self.config.vqvae_embedding_dim
|
||||
)
|
||||
return_decoder_input = einops.rearrange(
|
||||
centers.clone().detach(), "NT 1 D -> NT D"
|
||||
)
|
||||
return_decoder_input = self.vqvae_model.draw_code_forward(sampled_centers).clone().detach()
|
||||
# pass the centroids through decoder to get actions.
|
||||
decoded_action = (
|
||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
|
|
Loading…
Reference in New Issue