fixing tdmpc
This commit is contained in:
parent
05c321d5f2
commit
b7cbc5867e
|
@ -122,7 +122,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||||
|
|
||||||
# When the action queue is depleted, populate it again by querying the policy.
|
# When the action queue is depleted, populate it again by querying the policy.
|
||||||
if len(self._queues["action"]) == 0:
|
if len(self._queues["action"]) == 0:
|
||||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
|
||||||
|
|
||||||
# Remove the time dimensions as it is not handled yet.
|
# Remove the time dimensions as it is not handled yet.
|
||||||
for key in batch:
|
for key in batch:
|
||||||
|
|
|
@ -124,7 +124,6 @@ def rollout(
|
||||||
|
|
||||||
# Reset the policy and environments.
|
# Reset the policy and environments.
|
||||||
policy.reset()
|
policy.reset()
|
||||||
|
|
||||||
observation, info = env.reset(seed=seeds)
|
observation, info = env.reset(seed=seeds)
|
||||||
if render_callback is not None:
|
if render_callback is not None:
|
||||||
render_callback(env)
|
render_callback(env)
|
||||||
|
|
Loading…
Reference in New Issue