remove check_discretized at VQBeTPolicy, and directly use the function in class VqVae
This commit is contained in:
parent
651d9f46e5
commit
6d9a65ca39
|
@ -73,9 +73,6 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
"action": deque(maxlen=self.config.n_action_pred_chunk),
|
||||
}
|
||||
|
||||
def check_discretized(self):
|
||||
return self.vqbet.action_head.vqvae_model.check_discretized()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
@ -90,7 +87,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
batch = self.normalize_inputs(batch)
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
assert self.check_discretized(), "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ."
|
||||
assert self.vqbet.action_head.vqvae_model.check_discretized(), "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ."
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
|
@ -112,7 +109,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.check_discretized():
|
||||
if not self.vqbet.action_head.vqvae_model.check_discretized():
|
||||
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
|
||||
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin ped header / offset header parts.
|
||||
|
|
Loading…
Reference in New Issue