fixing tdmpc

This commit is contained in:
mshukor 2025-04-04 11:38:23 +02:00
parent 05c321d5f2
commit b7cbc5867e
2 changed files with 1 additions and 2 deletions

View File

@ -122,7 +122,7 @@ class TDMPCPolicy(PreTrainedPolicy):
# When the action queue is depleted, populate it again by querying the policy.
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
# Remove the time dimensions as it is not handled yet.
for key in batch:

View File

@ -124,7 +124,6 @@ def rollout(
# Reset the policy and environments.
policy.reset()
observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)