remove check_discretized at VQBeTPolicy, and directly use the function in class VqVae

This commit is contained in:
jayLEE0301 2024-06-03 18:29:26 -04:00
parent 651d9f46e5
commit 6d9a65ca39
1 changed files with 2 additions and 5 deletions

View File

@ -73,9 +73,6 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
"action": deque(maxlen=self.config.n_action_pred_chunk),
}
def check_discretized(self):
return self.vqbet.action_head.vqvae_model.check_discretized()
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
@ -90,7 +87,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch)
assert self.check_discretized(), "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ."
assert self.vqbet.action_head.vqvae_model.check_discretized(), "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ."
assert "observation.image" in batch
assert "observation.state" in batch
@ -112,7 +109,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.check_discretized():
if not self.vqbet.action_head.vqvae_model.check_discretized():
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin ped header / offset header parts.