diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 7397327d..95a33450 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -73,11 +73,11 @@ def download(data_dir, dataset_id): data_dir.mkdir(parents=True, exist_ok=True) - gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir) + gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir)) # because of the 50 files limit per directory, two files episode 48 and 49 were missing - gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True) - gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True) + gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True) + gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True) class AlohaExperienceReplay(AbstractExperienceReplay): diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index b9b13d66..d92c2f49 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -370,14 +370,10 @@ class AlohaEnv(EnvBase): raise NotImplementedError() # self._prev_action_queue = deque(maxlen=self.num_prev_action) - def render(self, mode="rgb_array", width=384, height=384): - if width != height: - raise NotImplementedError() - tmp = self._env.render_size - self._env.render_size = width - out = self._env.render(mode) - self._env.render_size = tmp - return out + def render(self, mode="rgb_array", width=640, height=480): + # TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close) + image = self._env.physics.render(height=height, width=width, camera_id="top") + return image def _format_raw_obs(self, raw_obs): if self.from_pixels: @@ -535,8 +531,10 @@ class AlohaEnv(EnvBase): # ) # TODO(rcaene): add bounds (where are they????) - self.action_spec = UnboundedContinuousTensorSpec( + self.action_spec = BoundedTensorSpec( shape=(len(ACTIONS)), + low=-1, + high=1, dtype=torch.float32, device=self.device, )