diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 855e073b..14d29bcc 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -20,6 +20,8 @@ def make_env(cfg, transform=None): from lerobot.common.envs.simxarm.env import SimxarmEnv kwargs["task"] = cfg.env.task + kwargs["visualization_width"] = cfg.env.visualization_width + kwargs["visualization_height"] = cfg.env.visualization_height clsfunc = SimxarmEnv elif cfg.env.name == "pusht": from lerobot.common.envs.pusht.env import PushtEnv diff --git a/lerobot/common/envs/simxarm/env.py b/lerobot/common/envs/simxarm/env.py index b8f19057..dfd684cf 100644 --- a/lerobot/common/envs/simxarm/env.py +++ b/lerobot/common/envs/simxarm/env.py @@ -38,7 +38,11 @@ class SimxarmEnv(AbstractEnv): device="cpu", num_prev_obs=0, num_prev_action=0, + visualization_width=400, + visualization_height=400, ): + self.visualization_width = visualization_width + self.visualization_height = visualization_height self.image_size = image_size super().__init__( task=task, @@ -63,7 +67,12 @@ class SimxarmEnv(AbstractEnv): if self.task not in TASKS: raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}") - kwargs = {"width": self.image_size, "height": self.image_size} + kwargs = { + "width": self.image_size, + "height": self.image_size, + "visualization_width": self.visualization_width, + "visualization_height": self.visualization_height, + } self._env = TASKS[self.task]["env"](**kwargs) num_actions = len(TASKS[self.task]["action_space"]) @@ -72,12 +81,12 @@ class SimxarmEnv(AbstractEnv): if "w" not in TASKS[self.task]["action_space"]: self._action_padding[-1] = 1.0 - def render(self, mode="rgb_array", width=384, height=384): - return self._env.render(mode, width=width, height=height) + def render(self, mode="rgb_array"): + return self._env.render(mode) def _format_raw_obs(self, raw_obs): if self.from_pixels: - image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size) + image = self.render(mode="rgb_array") image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W) image = torch.tensor(image.copy(), dtype=torch.uint8) diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/base.py b/lerobot/common/envs/simxarm/simxarm/tasks/base.py index 167dafe8..fa7afad6 100644 --- a/lerobot/common/envs/simxarm/simxarm/tasks/base.py +++ b/lerobot/common/envs/simxarm/simxarm/tasks/base.py @@ -1,7 +1,9 @@ +# from copy import deepcopy import os import mujoco import numpy as np +from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer from gymnasium_robotics.envs import robot_env from lerobot.common.envs.simxarm.simxarm.tasks import mocap @@ -22,6 +24,9 @@ class Base(robot_env.MujocoRobotEnv): self.center_of_table = np.array([1.655, 0.3, 0.63625]) self.max_z = 1.2 self.min_z = 0.2 + visualization_width = kwargs.pop("visualization_width") + visualization_height = kwargs.pop("visualization_height") + super().__init__( model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"), n_substeps=20, @@ -30,6 +35,8 @@ class Base(robot_env.MujocoRobotEnv): **kwargs, ) + self._set_custom_size_renderer(width=visualization_width, height=visualization_height) + @property def dt(self): return self.n_substeps * self.model.opt.timestep @@ -134,10 +141,27 @@ class Base(robot_env.MujocoRobotEnv): info = {"is_success": self.is_success(), "success": self.is_success()} return obs, reward, done, info - def render(self, mode="rgb_array", width=384, height=384): + def render(self, mode="rgb_array"): self._render_callback() + + if mode == "visualization": + return self._custom_size_render() + return self.mujoco_renderer.render(mode, camera_name="camera0") + def _set_custom_size_renderer(self, width, height): + from copy import deepcopy + + # HACK + custom_render_model = deepcopy(self.model) + custom_render_model.vis.global_.offwidth = width + custom_render_model.vis.global_.offheight = height + self.custom_size_renderer = MujocoRenderer(custom_render_model, self.data) + del custom_render_model + + def _custom_size_render(self): + return self.custom_size_renderer.render("rgb_array", camera_name="camera0") + def close(self): if self.mujoco_renderer is not None: self.mujoco_renderer.close() diff --git a/lerobot/configs/env/simxarm.yaml b/lerobot/configs/env/simxarm.yaml index f79db8f7..4255fa13 100644 --- a/lerobot/configs/env/simxarm.yaml +++ b/lerobot/configs/env/simxarm.yaml @@ -20,6 +20,8 @@ env: action_repeat: 2 episode_length: 25 fps: ${fps} + visualization_width: 400 + visualization_height: 400 policy: state_dim: 4 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 28a25e43..75173e3f 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -86,7 +86,7 @@ def eval_policy( def maybe_render_frame(env: EnvBase, _): if save_video or (return_first_video and i == 0): # noqa: B023 - ep_frames.append(env.render()) # noqa: B023 + ep_frames.append(env.render(mode="visualization")) # noqa: B023 # Clear the policy's action queue before the start of a new rollout. if policy is not None: