remove .cuda() parts

This commit is contained in:
jayLEE0301 2024-06-05 11:46:32 -04:00
parent 75bdcaab81
commit 36525b3d2e
1 changed files with 4 additions and 3 deletions

View File

@ -442,10 +442,11 @@ class VQBeTHead(nn.Module):
"(NT G) 1 -> NT G",
NT=NT,
)
device = get_device_from_parameters(self)
indices = (
torch.arange(NT).unsqueeze(1).cuda(),
torch.arange(self.config.vqvae_groups).unsqueeze(0).cuda(),
torch.arange(NT, device=device).unsqueeze(1),
torch.arange(self.config.vqvae_groups, device=device).unsqueeze(0),
sampled_centers,
)
# Use advanced indexing to sample the values