remove unused parts in def discretize
This commit is contained in:
parent
06d3eb40e1
commit
75bdcaab81
|
@ -379,12 +379,6 @@ class VQBeTHead(nn.Module):
|
|||
self._criterion = FocalLoss(gamma=2.0)
|
||||
|
||||
def discretize(self, discretize_step, actions):
|
||||
if next(self.vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
|
||||
self.vqvae_model.encoder.to(get_device_from_parameters(self))
|
||||
self.vqvae_model.vq_layer.to(get_device_from_parameters(self))
|
||||
self.vqvae_model.decoder.to(get_device_from_parameters(self))
|
||||
self.vqvae_model.device = get_device_from_parameters(self)
|
||||
|
||||
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
|
||||
# if we updated RVQ more than `discretize_step` steps,
|
||||
if self.vqvae_model.check_discretized():
|
||||
|
|
Loading…
Reference in New Issue