change explanations
This commit is contained in:
parent
36525b3d2e
commit
63d198cf0a
|
@ -449,9 +449,9 @@ class VQBeTHead(nn.Module):
|
|||
torch.arange(self.config.vqvae_groups, device=device).unsqueeze(0),
|
||||
sampled_centers,
|
||||
)
|
||||
# Use advanced indexing to sample the values
|
||||
# Use advanced indexing to sample the values (Extract the only offsets corresponding to the sampled codes.)
|
||||
sampled_offsets = cbet_offsets[indices]
|
||||
# Extract the only offsets corresponding to the sampled codes.
|
||||
# Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
# Get the centroids of each layer to pass it through RVQ decoder
|
||||
centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
|
||||
|
|
Loading…
Reference in New Issue