remove .cuda() parts
This commit is contained in:
parent
75bdcaab81
commit
36525b3d2e
|
@ -442,10 +442,11 @@ class VQBeTHead(nn.Module):
|
|||
"(NT G) 1 -> NT G",
|
||||
NT=NT,
|
||||
)
|
||||
|
||||
|
||||
device = get_device_from_parameters(self)
|
||||
indices = (
|
||||
torch.arange(NT).unsqueeze(1).cuda(),
|
||||
torch.arange(self.config.vqvae_groups).unsqueeze(0).cuda(),
|
||||
torch.arange(NT, device=device).unsqueeze(1),
|
||||
torch.arange(self.config.vqvae_groups, device=device).unsqueeze(0),
|
||||
sampled_centers,
|
||||
)
|
||||
# Use advanced indexing to sample the values
|
||||
|
|
Loading…
Reference in New Issue