remove redundant self.eval(), self.train(), and removed def toggle_discretized
This commit is contained in:
parent
4d3a45ac26
commit
1257aaf438
|
@ -82,7 +82,6 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
queue is empty.
|
||||
"""
|
||||
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
@ -777,19 +776,10 @@ class VqVae(nn.Module):
|
|||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.action_chunk_size],
|
||||
)
|
||||
|
||||
self.train()
|
||||
|
||||
def toggle_discretized(self, state=True):
|
||||
self.discretized = torch.tensor(state)
|
||||
|
||||
def check_discretized(self):
|
||||
return self.discretized.item()
|
||||
|
||||
def eval(self):
|
||||
self.training = False
|
||||
self.vq_layer.eval()
|
||||
self.encoder.eval()
|
||||
self.decoder.eval()
|
||||
|
||||
def train(self, mode=True):
|
||||
"""
|
||||
|
@ -885,8 +875,7 @@ class VqVae(nn.Module):
|
|||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
super(VqVae, self).state_dict(self, *args, **kwargs)
|
||||
self.eval()
|
||||
self.toggle_discretized(True)
|
||||
self.discretized = torch.tensor(True)
|
||||
|
||||
|
||||
|
||||
|
@ -913,7 +902,7 @@ def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
|
|||
vqvae_model.optimized_steps += 1
|
||||
# if we updated RVQ more than `n_vqvae_training_steps` steps,
|
||||
if vqvae_model.optimized_steps >= n_vqvae_training_steps:
|
||||
vqvae_model.toggle_discretized(True)
|
||||
vqvae_model.discretized = torch.tensor(True)
|
||||
print("Finished discretizing action data!")
|
||||
vqvae_model.eval()
|
||||
for param in vqvae_model.vq_layer.parameters():
|
||||
|
|
Loading…
Reference in New Issue