remove unused parts in def discretize

This commit is contained in:
jayLEE0301 2024-06-05 11:36:10 -04:00
parent 06d3eb40e1
commit 75bdcaab81
1 changed files with 0 additions and 6 deletions

View File

@ -379,12 +379,6 @@ class VQBeTHead(nn.Module):
self._criterion = FocalLoss(gamma=2.0)
def discretize(self, discretize_step, actions):
if next(self.vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
self.vqvae_model.encoder.to(get_device_from_parameters(self))
self.vqvae_model.vq_layer.to(get_device_from_parameters(self))
self.vqvae_model.decoder.to(get_device_from_parameters(self))
self.vqvae_model.device = get_device_from_parameters(self)
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
# if we updated RVQ more than `discretize_step` steps,
if self.vqvae_model.check_discretized():