Clear action queue when environment is reset

Alexander Soare 2024-03-20 08:31:06 +00:00
parent c5010fee9a
commit 4f1955edfd
2 changed files with 9 additions and 3 deletions

lerobot/common/policies/abstract.py

@@ -20,8 +20,7 @@ class AbstractPolicy(nn.Module, ABC):
         """
         super().__init__()
         self.n_action_steps = n_action_steps
-        if n_action_steps is not None:
-            self._action_queue = deque([], maxlen=n_action_steps)
+        self.clear_action_queue()
 
     @abstractmethod
     def update(self, replay_buffer, step):
@@ -42,6 +41,11 @@ class AbstractPolicy(nn.Module, ABC):
         actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
         """
 
+    def clear_action_queue(self):
+        """This should be called whenever the environment is reset."""
+        if self.n_action_steps is not None:
+            self._action_queue = deque([], maxlen=self.n_action_steps)
+
     def forward(self, *args, **kwargs) -> Tensor:
         """Inference step that makes multi-step policies compatible with their single-step environments.

lerobot/scripts/eval.py

@@ -15,6 +15,7 @@ from torchrl.envs.batched_envs import BatchedEnvBase
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import log_output_dir
+from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils import init_logging, set_seed
@@ -25,7 +26,7 @@ def write_video(video_path, stacked_frames, fps):
 
 def eval_policy(
     env: BatchedEnvBase,
-    policy: TensorDictModule = None,
+    policy: AbstractPolicy,
     num_episodes: int = 10,
     max_steps: int = 30,
     save_video: bool = False,
@@ -53,6 +54,7 @@
     with torch.inference_mode():
         # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
         # envs are done the first time. But we only use the first rollout. This is a waste of compute.
+        policy.clear_action_queue()
         rollout = env.rollout(
             max_steps=max_steps,
             policy=policy,
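
The new clear_action_queue() call before env.rollout matters because queued actions can outlive an episode. A hedged, self-contained sketch of the failure mode this prevents (all names hypothetical, with deque.clear() standing in for re-creating the queue):

from collections import deque

queue = deque([], maxlen=4)  # n_action_steps = 4
for episode in range(2):
    queue.clear()  # the analogue of policy.clear_action_queue() after env reset
    for step in range(3):  # episode ends before the 4-action chunk is used up
        if not queue:
            queue.extend(f"ep{episode}-a{i}" for i in range(4))  # predict 4 actions
        print(queue.popleft())

This prints ep0-a0 through ep0-a2, then ep1-a0 through ep1-a2. Without the clear at reset, the second episode would first execute the stale ep0-a3 left over from the first.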