remove .discretized, and add register_buffer

This commit is contained in:
jayLEE0301 2024-06-03 18:25:30 -04:00
parent f0508d02b9
commit 651d9f46e5
1 changed files with 13 additions and 11 deletions

View File

@ -74,7 +74,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
}
def check_discretized(self):
return self.vqbet.action_head.vqvae_model.discretized
return self.vqbet.action_head.vqvae_model.check_discretized()
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@ -90,11 +90,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch)
if not self.check_discretized():
self.vqbet.action_head.vqvae_model.discretized = True
# VQ-BeT can predict action only after finishing action discretization.
# We added a logit to force self.vqbet.action_head.vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')
assert self.check_discretized(), "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ."
assert "observation.image" in batch
assert "observation.state" in batch
@ -390,7 +386,7 @@ class VQBeTHead(nn.Module):
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
# if we updated RVQ more than `discretize_step` steps,
if self.vqvae_model.discretized:
if self.vqvae_model.check_discretized():
print("Finished discretizing action data!")
self.vqvae_model.eval()
for param in self.vqvae_model.vq_layer.parameters():
@ -763,7 +759,7 @@ class VqVae(nn.Module):
super(VqVae, self).__init__()
self.config = config
self.discretized = False
self.register_buffer('discretized', torch.tensor(False))
self.optimized_steps = 0
self.vq_layer = ResidualVQ(
@ -783,6 +779,12 @@ class VqVae(nn.Module):
self.train()
def toggle_discretized(self, state=True):
self.discretized = torch.tensor(state)
def check_discretized(self):
return self.discretized.item()
def eval(self):
self.training = False
self.vq_layer.eval()
@ -796,7 +798,7 @@ class VqVae(nn.Module):
Therefore, we use function overriding to prevent RVQs from being updated during the training of VQ-BeT after discretization completes.
"""
if mode:
if self.discretized:
if self.check_discretized():
pass
else:
self.training = True
@ -884,7 +886,7 @@ class VqVae(nn.Module):
def load_state_dict(self, *args, **kwargs):
super(VqVae, self).state_dict(self, *args, **kwargs)
self.eval()
self.discretized = True
self.toggle_discretized(True)
@ -910,7 +912,7 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
n_different_combinations = len(torch.unique(metric[2], dim=0))
vqvae_model.optimized_steps += 1
if vqvae_model.optimized_steps >= discretize_step:
vqvae_model.discretized = True
vqvae_model.toggle_discretized(True)
return loss, n_different_codes, n_different_combinations