From 29032fbcd3309395a2a81289a2b23b900ca93d1c Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 5 Apr 2024 17:17:14 +0000 Subject: [PATCH] wrap dm_control aloha into gymnasium (TODO: properly seeding the env) --- lerobot/common/envs/aloha/env.py | 303 +++++++++---------------------- lerobot/common/envs/factory.py | 7 + tests/test_envs.py | 26 ++- 3 files changed, 104 insertions(+), 232 deletions(-) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 8f907650..719c2d19 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -1,22 +1,9 @@ -import importlib -import logging -from collections import deque -from typing import Optional - -import einops +import gymnasium as gym import numpy as np -import torch from dm_control import mujoco from dm_control.rl import control -from tensordict import TensorDict -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from gymnasium import spaces -from lerobot.common.envs.abstract import AbstractEnv from lerobot.common.envs.aloha.constants import ( ACTIONS, ASSETS_DIR, @@ -31,49 +18,67 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import ( from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose from lerobot.common.utils import set_global_seed -_has_gym = importlib.util.find_spec("gymnasium") is not None - -class AlohaEnv(AbstractEnv): - name = "aloha" - available_tasks = ["sim_insertion", "sim_transfer_cube"] - _reset_warning_issued = False +class AlohaEnv(gym.Env): + metadata = {"render_modes": [], "render_fps": 50} def __init__( self, task, - frame_skip: int = 1, - from_pixels: bool = False, - pixels_only: bool = False, - image_size=None, - seed=1337, - device="cpu", - num_prev_obs=1, - num_prev_action=0, + obs_type="state", + observation_width=640, + observation_height=480, + visualization_width=640, + visualization_height=480, ): - super().__init__( - task=task, - frame_skip=frame_skip, - from_pixels=from_pixels, - pixels_only=pixels_only, - image_size=image_size, - seed=seed, - device=device, - num_prev_obs=num_prev_obs, - num_prev_action=num_prev_action, - ) - - def _make_env(self): - if not _has_gym: - raise ImportError("Cannot import gymnasium.") - - if not self.from_pixels: - raise NotImplementedError() + super().__init__() + self.task = task + self.obs_type = obs_type + self.observation_width = observation_width + self.observation_height = observation_height + self.visualization_width = visualization_width + self.visualization_height = visualization_height self._env = self._make_env_task(self.task) - def render(self, mode="rgb_array", width=640, height=480): + if self.obs_type == "state": + raise NotImplementedError() + self.observation_space = spaces.Box( + low=np.array([0] * len(JOINTS)), # ??? + high=np.array([255] * len(JOINTS)), # ??? + dtype=np.float64, + ) + elif self.obs_type == "pixels": + self.observation_space = spaces.Box( + low=0, high=255, shape=(self.observation_height, self.observation_width, 3), dtype=np.uint8 + ) + elif self.obs_type == "pixels_agent_pos": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ), + "agent_pos": spaces.Box( + low=np.array([-1] * len(JOINTS)), # ??? + high=np.array([1] * len(JOINTS)), # ??? + dtype=np.float64, + ), + } + ) + + self.action_space = spaces.Box(low=-1, high=1, shape=(len(ACTIONS),), dtype=np.float32) + + def render(self, mode="rgb_array"): # TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close) + if mode in ["visualize", "human"]: + height, width = self.visualize_height, self.visualize_width + elif mode == "rgb_array": + height, width = self.observation_height, self.observation_width + else: + raise ValueError(mode) image = self._env.physics.render(height=height, width=width, camera_id="top") return image @@ -81,20 +86,20 @@ class AlohaEnv(AbstractEnv): # time limit is controlled by StepCounter in env factory time_limit = float("inf") - if "sim_transfer_cube" in task_name: + if "transfer_cube" in task_name: xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml" physics = mujoco.Physics.from_xml_path(str(xml_path)) task = TransferCubeTask(random=False) - elif "sim_insertion" in task_name: + elif "insertion" in task_name: xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml" physics = mujoco.Physics.from_xml_path(str(xml_path)) task = InsertionTask(random=False) - elif "sim_end_effector_transfer_cube" in task_name: + elif "end_effector_transfer_cube" in task_name: raise NotImplementedError() xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml" physics = mujoco.Physics.from_xml_path(str(xml_path)) task = TransferCubeEndEffectorTask(random=False) - elif "sim_end_effector_insertion" in task_name: + elif "end_effector_insertion" in task_name: raise NotImplementedError() xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml" physics = mujoco.Physics.from_xml_path(str(xml_path)) @@ -108,191 +113,55 @@ class AlohaEnv(AbstractEnv): return env def _format_raw_obs(self, raw_obs): - if self.from_pixels: - image = torch.from_numpy(raw_obs["images"]["top"].copy()) - image = einops.rearrange(image, "h w c -> c h w") - assert image.dtype == torch.uint8 - obs = {"image": {"top": image}} - - if not self.pixels_only: - obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32) - else: - # TODO(rcadene): + if self.obs_type == "state": raise NotImplementedError() - # obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)} - + elif self.obs_type == "pixels": + obs = raw_obs["images"]["top"].copy() + elif self.obs_type == "pixels_agent_pos": + obs = { + "pixels": raw_obs["images"]["top"].copy(), + "agent_pos": raw_obs["qpos"], + } return obs - def _reset(self, tensordict: Optional[TensorDict] = None): - if tensordict is not None and not AlohaEnv._reset_warning_issued: - logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") - AlohaEnv._reset_warning_issued = True + def reset(self, seed=None, options=None): + super().reset(seed=seed) - # Seed the environment and update the seed to be used for the next reset. - self._next_seed = self.set_seed(self._next_seed) + # TODO(rcadene): how to seed the env? + if seed is not None: + set_global_seed(seed) + self._env.task.random.seed(seed) # TODO(rcadene): do not use global variable for this - if "sim_transfer_cube" in self.task: + if "transfer_cube" in self.task: BOX_POSE[0] = sample_box_pose() # used in sim reset - elif "sim_insertion" in self.task: + elif "insertion" in self.task: BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset + else: + raise ValueError(self.task) raw_obs = self._env.reset() - obs = self._format_raw_obs(raw_obs.observation) + observation = self._format_raw_obs(raw_obs.observation) - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs + info = {"is_success": False} + return observation, info - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - - return td - - def _step(self, tensordict: TensorDict): - td = tensordict - action = td["action"].numpy() + def step(self, action): assert action.ndim == 1 # TODO(rcadene): add info["is_success"] and info["success"] ? _, reward, _, raw_obs = self._env.step(action) # TODO(rcadene): add an enum - success = done = reward == 4 - obs = self._format_raw_obs(raw_obs) + terminated = is_success = reward == 4 - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue.append(obs["image"]["top"]) - stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} - if "state" in obs: - self._prev_obs_state_queue.append(obs["state"]) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs + info = {"is_success": is_success} - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "reward": torch.tensor([reward], dtype=torch.float32), - # success and done are true when coverage > self.success_threshold in env - "done": torch.tensor([done], dtype=torch.bool), - "success": torch.tensor([success], dtype=torch.bool), - }, - batch_size=[], - ) - return td + observation = self._format_raw_obs(raw_obs) - def _make_spec(self): - obs = {} - from omegaconf import OmegaConf + truncated = False + return observation, reward, terminated, truncated, info - if self.from_pixels: - if isinstance(self.image_size, int): - image_shape = (3, self.image_size, self.image_size) - elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list): - assert len(self.image_size) == 3 # c h w - assert self.image_size[0] == 3 # c is RGB - image_shape = tuple(self.image_size) - else: - raise ValueError(self.image_size) - if self.num_prev_obs > 0: - image_shape = (self.num_prev_obs + 1, *image_shape) - - obs["image"] = { - "top": BoundedTensorSpec( - low=0, - high=255, - shape=image_shape, - dtype=torch.uint8, - device=self.device, - ) - } - if not self.pixels_only: - state_shape = (len(JOINTS),) - if self.num_prev_obs > 0: - state_shape = (self.num_prev_obs + 1, *state_shape) - - obs["state"] = UnboundedContinuousTensorSpec( - # TODO: add low and high bounds - shape=state_shape, - dtype=torch.float32, - device=self.device, - ) - else: - # TODO(rcadene): add observation_space achieved_goal and desired_goal? - state_shape = (len(JOINTS),) - if self.num_prev_obs > 0: - state_shape = (self.num_prev_obs + 1, *state_shape) - - obs["state"] = UnboundedContinuousTensorSpec( - # TODO: add low and high bounds - shape=state_shape, - dtype=torch.float32, - device=self.device, - ) - self.observation_spec = CompositeSpec({"observation": obs}) - - # TODO(rcadene): valid when controling end effector? - # action_space = self._env.action_spec() - # self.action_spec = BoundedTensorSpec( - # low=action_space.minimum, - # high=action_space.maximum, - # shape=action_space.shape, - # dtype=torch.float32, - # device=self.device, - # ) - - # TODO(rcaene): add bounds (where are they????) - self.action_spec = BoundedTensorSpec( - shape=(len(ACTIONS)), - low=-1, - high=1, - dtype=torch.float32, - device=self.device, - ) - - self.reward_spec = UnboundedContinuousTensorSpec( - shape=(1,), - dtype=torch.float32, - device=self.device, - ) - - self.done_spec = CompositeSpec( - { - "done": DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, - ), - "success": DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, - ), - } - ) - - def _set_seed(self, seed: Optional[int]): - set_global_seed(seed) - # TODO(rcadene): seed the env - # self._env.seed(seed) - logging.warning("Aloha env is not seeded") + def close(self): + pass diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 4ddb81a2..9225cbc5 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -30,7 +30,14 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: **kwargs, ) elif cfg.env.name == "aloha": + from lerobot.common.envs import aloha as gym_aloha # noqa: F401 + kwargs["task"] = cfg.env.task + + env_fn = lambda: gym.make( # noqa: E731 + "gym_aloha/AlohaInsertion-v0", + **kwargs, + ) else: raise ValueError(cfg.env.name) diff --git a/tests/test_envs.py b/tests/test_envs.py index a94d76f2..495453e2 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -41,25 +41,21 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH # print("data from rollout:", simple_rollout(100)) -@pytest.mark.skip("TODO") @pytest.mark.parametrize( - "task,from_pixels,pixels_only", + "env_task, obs_type", [ - ("sim_insertion", True, False), - ("sim_insertion", True, True), - ("sim_transfer_cube", True, False), - ("sim_transfer_cube", True, True), + # ("AlohaInsertion-v0", "state"), + ("AlohaInsertion-v0", "pixels"), + ("AlohaInsertion-v0", "pixels_agent_pos"), + ("AlohaTransferCube-v0", "pixels"), + ("AlohaTransferCube-v0", "pixels_agent_pos"), ], ) -def test_aloha(task, from_pixels, pixels_only): - env = AlohaEnv( - task, - from_pixels=from_pixels, - pixels_only=pixels_only, - image_size=[3, 480, 640] if from_pixels else None, - ) - # print_spec_rollout(env) - check_env_specs(env) +def test_aloha(env_task, obs_type): + from lerobot.common.envs import aloha as gym_aloha # noqa: F401 + env = gym.make(f"gym_aloha/{env_task}", obs_type=obs_type) + check_env(env) + @pytest.mark.parametrize(