change explanations

This commit is contained in:
jayLEE0301 2024-06-05 11:52:32 -04:00
parent 36525b3d2e
commit 63d198cf0a
1 changed files with 2 additions and 2 deletions

View File

@ -449,9 +449,9 @@ class VQBeTHead(nn.Module):
torch.arange(self.config.vqvae_groups, device=device).unsqueeze(0),
sampled_centers,
)
# Use advanced indexing to sample the values
# Use advanced indexing to sample the values (Extract the only offsets corresponding to the sampled codes.)
sampled_offsets = cbet_offsets[indices]
# Extract the only offsets corresponding to the sampled codes.
# Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
sampled_offsets = sampled_offsets.sum(dim=1)
# Get the centroids of each layer to pass it through RVQ decoder
centers = self.vqvae_model.draw_code_forward(sampled_centers).view(