add self.reset() at the bottom of __init__ in class VQBeTPolicy
This commit is contained in:
parent
a33fbd4d44
commit
f0508d02b9
|
@ -61,9 +61,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
def check_discretized(self):
|
||||
return self.vqbet.action_head.vqvae_model.discretized
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
@ -75,6 +73,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
"action": deque(maxlen=self.config.n_action_pred_chunk),
|
||||
}
|
||||
|
||||
def check_discretized(self):
|
||||
return self.vqbet.action_head.vqvae_model.discretized
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
|
Loading…
Reference in New Issue