add self.reset() at the bottom of __init__ in class VQBeTPolicy

This commit is contained in:
jayLEE0301 2024-06-03 17:45:59 -04:00
parent a33fbd4d44
commit f0508d02b9
1 changed files with 4 additions and 3 deletions

View File

@ -61,9 +61,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
self.vqbet = VQBeTModel(config)
def check_discretized(self):
return self.vqbet.action_head.vqvae_model.discretized
self.reset()
def reset(self):
"""
@ -75,6 +73,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
"action": deque(maxlen=self.config.n_action_pred_chunk),
}
def check_discretized(self):
return self.vqbet.action_head.vqvae_model.discretized
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.