From 4f1955edfdc7515f85ec9d70361932cd45e1c327 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Wed, 20 Mar 2024 08:31:06 +0000
Subject: [PATCH] Clear action queue when environment is reset

---
 lerobot/common/policies/abstract.py | 8 ++++++--
 lerobot/scripts/eval.py             | 4 +++-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
index 9f16f5d7..272ffcf4 100644
--- a/lerobot/common/policies/abstract.py
+++ b/lerobot/common/policies/abstract.py
@@ -20,8 +20,7 @@ class AbstractPolicy(nn.Module, ABC):
         """
         super().__init__()
         self.n_action_steps = n_action_steps
-        if n_action_steps is not None:
-            self._action_queue = deque([], maxlen=n_action_steps)
+        self.clear_action_queue()
 
     @abstractmethod
     def update(self, replay_buffer, step):
@@ -42,6 +41,11 @@ class AbstractPolicy(nn.Module, ABC):
         actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of
         actions.
         """
 
+    def clear_action_queue(self):
+        """This should be called whenever the environment is reset."""
+        if self.n_action_steps is not None:
+            self._action_queue = deque([], maxlen=self.n_action_steps)
+
     def forward(self, *args, **kwargs) -> Tensor:
         """Inference step that makes multi-step policies compatible with their single-step environments.
 
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 86d4158e..1e44c5df 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -15,6 +15,7 @@ from torchrl.envs.batched_envs import BatchedEnvBase
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import log_output_dir
+from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils import init_logging, set_seed
 
@@ -25,7 +26,7 @@ def write_video(video_path, stacked_frames, fps):
 
 def eval_policy(
     env: BatchedEnvBase,
-    policy: TensorDictModule = None,
+    policy: AbstractPolicy,
     num_episodes: int = 10,
     max_steps: int = 30,
     save_video: bool = False,
@@ -53,6 +54,7 @@ def eval_policy(
     with torch.inference_mode():
         # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
         # envs are done the first time. But we only use the first rollout. This is a waste of compute.
+        policy.clear_action_queue()
        rollout = env.rollout(
             max_steps=max_steps,
             policy=policy,
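
Below is a minimal, self-contained sketch (not the repository's exact code) of the queue mechanics this patch touches. The class name QueuedPolicy, the observation shapes, and the random select_actions body are stand-ins; only clear_action_queue mirrors the patch verbatim. It illustrates why the queue must be cleared on reset: forward serves one queued action per environment step and refills the queue from select_actions only when it is empty, so any actions left over from a previous episode would otherwise be replayed into the next one.

from collections import deque

import torch


class QueuedPolicy:
    """Stand-in for AbstractPolicy's multi-step action queue."""

    def __init__(self, n_action_steps: int | None):
        self.n_action_steps = n_action_steps
        self.clear_action_queue()

    def clear_action_queue(self):
        # Mirrors the patch: this should be called whenever the environment is reset.
        if self.n_action_steps is not None:
            self._action_queue = deque([], maxlen=self.n_action_steps)

    def select_actions(self, obs: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for a real policy: returns a chunk of
        # shape (batch, n_action_steps, action_dim).
        return torch.randn(obs.shape[0], self.n_action_steps, 2)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # Refill the queue with a fresh action chunk when it runs dry,
        # then serve the actions one environment step at a time.
        if len(self._action_queue) == 0:
            self._action_queue.extend(self.select_actions(obs).transpose(0, 1))
        return self._action_queue.popleft()


policy = QueuedPolicy(n_action_steps=4)
obs = torch.zeros(1, 8)
for _ in range(3):
    policy.forward(obs)          # mid-episode: the queue drains one step at a time
assert len(policy._action_queue) == 1
policy.clear_action_queue()      # at reset: drop the remaining stale action
assert len(policy._action_queue) == 0

Calling policy.clear_action_queue() right before env.rollout(...) in eval_policy, as the patch does, guarantees each evaluation rollout starts from an empty queue rather than inheriting actions chunked for the previous episode.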