Fix eval for other envs

This commit is contained in:
Simon Alibert 2024-04-01 00:25:36 +02:00
parent d0cd39f9b5
commit dbdd7d0c47
1 changed files with 5 additions and 1 deletions

View File

@ -49,6 +49,7 @@ from torchrl.envs.batched_envs import BatchedEnvBase
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.simxarm.env import SimxarmEnv
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.factory import make_policy
@ -86,7 +87,10 @@ def eval_policy(
def maybe_render_frame(env: EnvBase, _):
if save_video or (return_first_video and i == 0): # noqa: B023
ep_frames.append(env.render(mode="visualization")) # noqa: B023
# HACK
# TODO(aliberts): set render_mode for all envs
render_mode = "visualization" if isinstance(env, SimxarmEnv) else "rgb_array"
ep_frames.append(env.render(mode=render_mode)) # noqa: B023
# Clear the policy's action queue before the start of a new rollout.
if policy is not None: