Fix visualization render size
This commit is contained in:
parent
bf93fc1875
commit
5bba325fd0
|
@ -20,6 +20,8 @@ def make_env(cfg, transform=None):
|
||||||
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||||
|
|
||||||
kwargs["task"] = cfg.env.task
|
kwargs["task"] = cfg.env.task
|
||||||
|
kwargs["visualization_width"] = cfg.env.visualization_width
|
||||||
|
kwargs["visualization_height"] = cfg.env.visualization_height
|
||||||
clsfunc = SimxarmEnv
|
clsfunc = SimxarmEnv
|
||||||
elif cfg.env.name == "pusht":
|
elif cfg.env.name == "pusht":
|
||||||
from lerobot.common.envs.pusht.env import PushtEnv
|
from lerobot.common.envs.pusht.env import PushtEnv
|
||||||
|
|
|
@ -38,7 +38,11 @@ class SimxarmEnv(AbstractEnv):
|
||||||
device="cpu",
|
device="cpu",
|
||||||
num_prev_obs=0,
|
num_prev_obs=0,
|
||||||
num_prev_action=0,
|
num_prev_action=0,
|
||||||
|
visualization_width=400,
|
||||||
|
visualization_height=400,
|
||||||
):
|
):
|
||||||
|
self.visualization_width = visualization_width
|
||||||
|
self.visualization_height = visualization_height
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
super().__init__(
|
super().__init__(
|
||||||
task=task,
|
task=task,
|
||||||
|
@ -63,7 +67,12 @@ class SimxarmEnv(AbstractEnv):
|
||||||
if self.task not in TASKS:
|
if self.task not in TASKS:
|
||||||
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||||
|
|
||||||
kwargs = {"width": self.image_size, "height": self.image_size}
|
kwargs = {
|
||||||
|
"width": self.image_size,
|
||||||
|
"height": self.image_size,
|
||||||
|
"visualization_width": self.visualization_width,
|
||||||
|
"visualization_height": self.visualization_height,
|
||||||
|
}
|
||||||
self._env = TASKS[self.task]["env"](**kwargs)
|
self._env = TASKS[self.task]["env"](**kwargs)
|
||||||
|
|
||||||
num_actions = len(TASKS[self.task]["action_space"])
|
num_actions = len(TASKS[self.task]["action_space"])
|
||||||
|
@ -72,12 +81,12 @@ class SimxarmEnv(AbstractEnv):
|
||||||
if "w" not in TASKS[self.task]["action_space"]:
|
if "w" not in TASKS[self.task]["action_space"]:
|
||||||
self._action_padding[-1] = 1.0
|
self._action_padding[-1] = 1.0
|
||||||
|
|
||||||
def render(self, mode="rgb_array", width=384, height=384):
|
def render(self, mode="rgb_array"):
|
||||||
return self._env.render(mode, width=width, height=height)
|
return self._env.render(mode)
|
||||||
|
|
||||||
def _format_raw_obs(self, raw_obs):
|
def _format_raw_obs(self, raw_obs):
|
||||||
if self.from_pixels:
|
if self.from_pixels:
|
||||||
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
|
image = self.render(mode="rgb_array")
|
||||||
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||||
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
# from copy import deepcopy
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import mujoco
|
import mujoco
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
|
||||||
from gymnasium_robotics.envs import robot_env
|
from gymnasium_robotics.envs import robot_env
|
||||||
|
|
||||||
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
|
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
|
||||||
|
@ -22,6 +24,9 @@ class Base(robot_env.MujocoRobotEnv):
|
||||||
self.center_of_table = np.array([1.655, 0.3, 0.63625])
|
self.center_of_table = np.array([1.655, 0.3, 0.63625])
|
||||||
self.max_z = 1.2
|
self.max_z = 1.2
|
||||||
self.min_z = 0.2
|
self.min_z = 0.2
|
||||||
|
visualization_width = kwargs.pop("visualization_width")
|
||||||
|
visualization_height = kwargs.pop("visualization_height")
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
|
||||||
n_substeps=20,
|
n_substeps=20,
|
||||||
|
@ -30,6 +35,8 @@ class Base(robot_env.MujocoRobotEnv):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._set_custom_size_renderer(width=visualization_width, height=visualization_height)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dt(self):
|
def dt(self):
|
||||||
return self.n_substeps * self.model.opt.timestep
|
return self.n_substeps * self.model.opt.timestep
|
||||||
|
@ -134,10 +141,27 @@ class Base(robot_env.MujocoRobotEnv):
|
||||||
info = {"is_success": self.is_success(), "success": self.is_success()}
|
info = {"is_success": self.is_success(), "success": self.is_success()}
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def render(self, mode="rgb_array", width=384, height=384):
|
def render(self, mode="rgb_array"):
|
||||||
self._render_callback()
|
self._render_callback()
|
||||||
|
|
||||||
|
if mode == "visualization":
|
||||||
|
return self._custom_size_render()
|
||||||
|
|
||||||
return self.mujoco_renderer.render(mode, camera_name="camera0")
|
return self.mujoco_renderer.render(mode, camera_name="camera0")
|
||||||
|
|
||||||
|
def _set_custom_size_renderer(self, width, height):
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
# HACK
|
||||||
|
custom_render_model = deepcopy(self.model)
|
||||||
|
custom_render_model.vis.global_.offwidth = width
|
||||||
|
custom_render_model.vis.global_.offheight = height
|
||||||
|
self.custom_size_renderer = MujocoRenderer(custom_render_model, self.data)
|
||||||
|
del custom_render_model
|
||||||
|
|
||||||
|
def _custom_size_render(self):
|
||||||
|
return self.custom_size_renderer.render("rgb_array", camera_name="camera0")
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.mujoco_renderer is not None:
|
if self.mujoco_renderer is not None:
|
||||||
self.mujoco_renderer.close()
|
self.mujoco_renderer.close()
|
||||||
|
|
|
@ -20,6 +20,8 @@ env:
|
||||||
action_repeat: 2
|
action_repeat: 2
|
||||||
episode_length: 25
|
episode_length: 25
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
|
visualization_width: 400
|
||||||
|
visualization_height: 400
|
||||||
|
|
||||||
policy:
|
policy:
|
||||||
state_dim: 4
|
state_dim: 4
|
||||||
|
|
|
@ -86,7 +86,7 @@ def eval_policy(
|
||||||
|
|
||||||
def maybe_render_frame(env: EnvBase, _):
|
def maybe_render_frame(env: EnvBase, _):
|
||||||
if save_video or (return_first_video and i == 0): # noqa: B023
|
if save_video or (return_first_video and i == 0): # noqa: B023
|
||||||
ep_frames.append(env.render()) # noqa: B023
|
ep_frames.append(env.render(mode="visualization")) # noqa: B023
|
||||||
|
|
||||||
# Clear the policy's action queue before the start of a new rollout.
|
# Clear the policy's action queue before the start of a new rollout.
|
||||||
if policy is not None:
|
if policy is not None:
|
||||||
|
|
Loading…
Reference in New Issue