From b9051118956e6f77317e3c03435c963de469ee18 Mon Sep 17 00:00:00 2001 From: Cadene Date: Sun, 24 Mar 2024 17:48:02 +0000 Subject: [PATCH] fix render issue --- lerobot/common/datasets/abstract.py | 6 +++--- lerobot/common/datasets/simxarm.py | 6 +++++- lerobot/common/envs/simxarm/simxarm/task/base.py | 7 +++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index d0206cc7..794d60a2 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -44,9 +44,9 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." ) - # HACK - if dataset_id == "xarm_lift_medium": - self.data_dir = self.root / self.dataset_id + # HACK: to remove before merge + self.data_dir = self.root / self.dataset_id + if not (self.data_dir / "replay_buffer").exists(): storage = self._download_and_preproc_obsolete() else: storage = self._download_or_load_dataset() diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index b5cec7e1..af006f76 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -113,7 +113,11 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") + td_data = ( + episode[0] + .expand(total_frames) + .memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer") + ) td_data[idx0:idx1] = episode diff --git a/lerobot/common/envs/simxarm/simxarm/task/base.py b/lerobot/common/envs/simxarm/simxarm/task/base.py index d5f54f72..d91f61c1 100644 --- a/lerobot/common/envs/simxarm/simxarm/task/base.py +++ b/lerobot/common/envs/simxarm/simxarm/task/base.py @@ -155,10 +155,9 @@ class Base(robot_env.MujocoRobotEnv): def render(self, mode="rgb_array", width=384, height=384): self._render_callback() - # if mode == 'rgb_array': - # return self.sim.render(width, height, camera_name='camera0', depth=False)[::-1, :, :] - # elif mode == "human": - # self._get_viewer(mode).render() + # hack + self.model.vis.global_.offwidth = width + self.model.vis.global_.offheight = height return self.mujoco_renderer.render(mode) def close(self):