Fix eval for other envs
This commit is contained in:
parent
d0cd39f9b5
commit
dbdd7d0c47
|
@ -49,6 +49,7 @@ from torchrl.envs.batched_envs import BatchedEnvBase
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
|
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||||
from lerobot.common.logger import log_output_dir
|
from lerobot.common.logger import log_output_dir
|
||||||
from lerobot.common.policies.abstract import AbstractPolicy
|
from lerobot.common.policies.abstract import AbstractPolicy
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
@ -86,7 +87,10 @@ def eval_policy(
|
||||||
|
|
||||||
def maybe_render_frame(env: EnvBase, _):
|
def maybe_render_frame(env: EnvBase, _):
|
||||||
if save_video or (return_first_video and i == 0): # noqa: B023
|
if save_video or (return_first_video and i == 0): # noqa: B023
|
||||||
ep_frames.append(env.render(mode="visualization")) # noqa: B023
|
# HACK
|
||||||
|
# TODO(aliberts): set render_mode for all envs
|
||||||
|
render_mode = "visualization" if isinstance(env, SimxarmEnv) else "rgb_array"
|
||||||
|
ep_frames.append(env.render(mode=render_mode)) # noqa: B023
|
||||||
|
|
||||||
# Clear the policy's action queue before the start of a new rollout.
|
# Clear the policy's action queue before the start of a new rollout.
|
||||||
if policy is not None:
|
if policy is not None:
|
||||||
|
|
Loading…
Reference in New Issue