Fix TDMPC policy: stack only the batch keys that have a queue
This commit is contained in:
parent
05c321d5f2
commit
b7cbc5867e
|
@ -122,7 +122,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||
|
||||
# When the action queue is depleted, populate it again by querying the policy.
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
|
||||
|
||||
# Remove the time dimension, as it is not handled yet.
|
||||
for key in batch:
|
||||
|
|
|
@ -124,7 +124,6 @@ def rollout(
|
|||
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
|
Loading…
Reference in New Issue