remove redundant self.eval(), self.train(), and removed def toggle_discretized

This commit is contained in:
jayLEE0301 2024-06-08 15:56:32 -04:00
parent 4d3a45ac26
commit 1257aaf438
1 changed files with 2 additions and 13 deletions

View File

@ -82,7 +82,6 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
queue is empty.
"""
self.eval()
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch)
@ -777,19 +776,10 @@ class VqVae(nn.Module):
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.action_chunk_size],
)
self.train()
def toggle_discretized(self, state=True):
self.discretized = torch.tensor(state)
def check_discretized(self):
return self.discretized.item()
def eval(self):
self.training = False
self.vq_layer.eval()
self.encoder.eval()
self.decoder.eval()
def train(self, mode=True):
"""
@ -885,8 +875,7 @@ class VqVae(nn.Module):
def load_state_dict(self, *args, **kwargs):
super(VqVae, self).state_dict(self, *args, **kwargs)
self.eval()
self.toggle_discretized(True)
self.discretized = torch.tensor(True)
@ -913,7 +902,7 @@ def pretrain_vqvae(vqvae_model, n_vqvae_training_steps, actions):
vqvae_model.optimized_steps += 1
# if we updated RVQ more than `n_vqvae_training_steps` steps,
if vqvae_model.optimized_steps >= n_vqvae_training_steps:
vqvae_model.toggle_discretized(True)
vqvae_model.discretized = torch.tensor(True)
print("Finished discretizing action data!")
vqvae_model.eval()
for param in vqvae_model.vq_layer.parameters():