Clear action queue when environment is reset
This commit is contained in:
parent
c5010fee9a
commit
4f1955edfd
|
@ -20,8 +20,7 @@ class AbstractPolicy(nn.Module, ABC):
|
|||
"""
|
||||
super().__init__()
|
||||
self.n_action_steps = n_action_steps
|
||||
if n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=n_action_steps)
|
||||
self.clear_action_queue()
|
||||
|
||||
@abstractmethod
|
||||
def update(self, replay_buffer, step):
|
||||
|
@ -42,6 +41,11 @@ class AbstractPolicy(nn.Module, ABC):
|
|||
actions. Otherwise if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
|
||||
"""
|
||||
|
||||
def clear_action_queue(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.n_action_steps)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Tensor:
|
||||
"""Inference step that makes multi-step policies compatible with their single-step environments.
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from torchrl.envs.batched_envs import BatchedEnvBase
|
|||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import init_logging, set_seed
|
||||
|
||||
|
@ -25,7 +26,7 @@ def write_video(video_path, stacked_frames, fps):
|
|||
|
||||
def eval_policy(
|
||||
env: BatchedEnvBase,
|
||||
policy: TensorDictModule = None,
|
||||
policy: AbstractPolicy,
|
||||
num_episodes: int = 10,
|
||||
max_steps: int = 30,
|
||||
save_video: bool = False,
|
||||
|
@ -53,6 +54,7 @@ def eval_policy(
|
|||
with torch.inference_mode():
|
||||
# TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
|
||||
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
|
||||
policy.clear_action_queue()
|
||||
rollout = env.rollout(
|
||||
max_steps=max_steps,
|
||||
policy=policy,
|
||||
|
|
Loading…
Reference in New Issue