167 lines
5.9 KiB
Python
167 lines
5.9 KiB
Python
|
from collections import OrderedDict, deque
|
||
|
|
||
|
import gymnasium as gym
|
||
|
import numpy as np
|
||
|
from gymnasium.wrappers import TimeLimit
|
||
|
|
||
|
from xarm.tasks.base import Base as Base
|
||
|
from xarm.tasks.lift import Lift
|
||
|
from xarm.tasks.peg_in_box import PegInBox
|
||
|
from xarm.tasks.push import Push
|
||
|
from xarm.tasks.reach import Reach
|
||
|
|
||
|
# Registry of the available SimXarm tasks, keyed by task name.
# Each entry records the task's env class, the subset of the 4-dim xArm
# action space it uses ("xyz" = position only, "xyzw" = position + gripper),
# the raw episode length in env steps, and a human-readable description.
# An OrderedDict keeps the task listing order stable for display/iteration.
TASKS = OrderedDict(
    [
        (
            "reach",
            dict(
                env=Reach,
                action_space="xyz",
                episode_length=50,
                description="Reach a target location with the end effector",
            ),
        ),
        (
            "push",
            dict(
                env=Push,
                action_space="xyz",
                episode_length=50,
                description="Push a cube to a target location",
            ),
        ),
        (
            "peg_in_box",
            dict(
                env=PegInBox,
                action_space="xyz",
                episode_length=50,
                description="Insert a peg into a box",
            ),
        ),
        (
            "lift",
            dict(
                env=Lift,
                action_space="xyzw",
                episode_length=50,
                description="Lift a cube above a height threshold",
            ),
        ),
    ]
)
|
||
|
|
||
|
|
||
|
class SimXarmWrapper(gym.Wrapper):
    """
    A wrapper for the SimXarm environments. This wrapper is used to
    convert the action and observation spaces to the correct format.

    Regardless of the wrapped env's API, `reset` returns an observation and
    `step` returns the legacy 4-tuple ``(obs, reward, done, info)``.
    """

    def __init__(self, env, task, obs_mode, image_size, action_repeat, frame_stack=1, channel_last=False):
        """
        Args:
            env: The (possibly TimeLimit-wrapped) SimXarm task environment.
            task (dict): A TASKS entry ('env', 'action_space', 'episode_length', ...).
            obs_mode (str): One of 'state', 'rgb', or 'all'.
            image_size (int): Side length of rendered square RGB observations.
            action_repeat (int): Number of env steps per wrapper step.
            frame_stack (int): Number of consecutive frames stacked into one RGB obs.
            channel_last (bool): If True images are (H, W, C), else (C, H, W).
        """
        super().__init__(env)
        self._env = env
        self.obs_mode = obs_mode
        self.image_size = image_size
        self.action_repeat = action_repeat
        self.frame_stack = frame_stack
        self._frames = deque([], maxlen=frame_stack)
        self.channel_last = channel_last
        # Episode length as seen by the agent, in wrapper steps.
        self._max_episode_steps = task["episode_length"] // action_repeat

        image_shape = (
            (image_size, image_size, 3 * frame_stack)
            if channel_last
            else (3 * frame_stack, image_size, image_size)
        )
        if obs_mode == "state":
            self.observation_space = env.observation_space["observation"]
        elif obs_mode == "rgb":
            self.observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8)
        elif obs_mode == "all":
            self.observation_space = gym.spaces.Dict(
                state=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32),
                rgb=gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8),
            )
        else:
            raise ValueError(f"Unknown obs_mode {obs_mode}. Must be one of [rgb, all, state]")
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(task["action_space"]),))
        # Actions are padded up to the full 4-dim xArm action (x, y, z, gripper).
        self.action_padding = np.zeros(4 - len(task["action_space"]), dtype=np.float32)
        if "w" not in task["action_space"]:
            # Tasks without gripper control keep the gripper closed (+1).
            self.action_padding[-1] = 1.0

    def _render_obs(self):
        """Render one RGB frame at the configured size, in the configured layout."""
        obs = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
        if not self.channel_last:
            obs = obs.transpose(2, 0, 1)
        return obs.copy()

    def _update_frames(self, reset=False):
        """Append a freshly rendered frame; on reset, fill the whole stack with it."""
        pixels = self._render_obs()
        self._frames.append(pixels)
        if reset:
            for _ in range(1, self.frame_stack):
                self._frames.append(pixels)
        assert len(self._frames) == self.frame_stack

    def transform_obs(self, obs, reset=False):
        """Convert a raw env observation into the format selected by `obs_mode`."""
        if self.obs_mode == "state":
            return obs["observation"]
        elif self.obs_mode == "rgb":
            self._update_frames(reset=reset)
            rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
            return rgb_obs
        elif self.obs_mode == "all":
            self._update_frames(reset=reset)
            rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
            # `robot_state` resolves on the wrapped env via gym.Wrapper attribute
            # forwarding — TODO confirm the task envs expose it.
            return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state)))
        else:
            raise ValueError(f"Unknown obs_mode {self.obs_mode}. Must be one of [rgb, all, state]")

    def reset(self):
        """Reset the env and return the transformed initial observation."""
        obs = self._env.reset()
        # gymnasium's reset returns (obs, info); legacy gym returns obs only.
        if isinstance(obs, tuple):
            obs = obs[0]
        return self.transform_obs(obs, reset=True)

    def step(self, action):
        """Repeat `action` for `action_repeat` env steps, summing the reward.

        Returns the legacy 4-tuple (obs, reward, done, info) regardless of
        whether the wrapped env speaks the old 4-tuple or gymnasium 5-tuple API.
        """
        action = np.concatenate([action, self.action_padding])
        reward = 0.0
        for _ in range(self.action_repeat):
            result = self._env.step(action)
            if len(result) == 5:
                # gymnasium API: (obs, reward, terminated, truncated, info).
                obs, r, terminated, truncated, info = result
                done = terminated or truncated
            else:
                obs, r, done, info = result
            reward += r
            if done:
                # Stop repeating once the episode ends so we neither step a
                # finished episode nor accumulate reward past its end.
                break
        return self.transform_obs(obs), reward, done, info

    def render(self, mode="rgb_array", width=384, height=384, **kwargs):
        """Render via the wrapped env.

        NOTE(review): gymnasium's `Env.render` takes no arguments; this relies
        on the custom xarm task envs accepting (mode, width=, height=) — confirm.
        """
        return self._env.render(mode, width=width, height=height)

    @property
    def state(self):
        """The wrapped env's low-dimensional robot state vector."""
        return self._env.robot_state
|
||
|
|
||
|
|
||
|
def make(task, obs_mode="state", image_size=84, action_repeat=1, frame_stack=1, channel_last=False, seed=0):
    """
    Create a new environment.

    Args:
        task (str): The task to create an environment for. Must be one of:
            - 'reach'
            - 'push'
            - 'peg_in_box'
            - 'lift'
        obs_mode (str): The observation mode to use. Must be one of:
            - 'state': Only state observations
            - 'rgb': RGB images
            - 'all': RGB images and state observations
        image_size (int): The size of the image observations
        action_repeat (int): The number of times to repeat the action
        frame_stack (int): The number of consecutive frames to stack in RGB observations
        channel_last (bool): Whether image observations are (H, W, C) instead of (C, H, W)
        seed (int): The random seed to use
    Returns:
        gym.Env: The environment
    Raises:
        ValueError: If `task` is not a key of TASKS.
    """
    if task not in TASKS:
        raise ValueError(f"Unknown task {task}. Must be one of {list(TASKS.keys())}")
    env = TASKS[task]["env"]()
    # Cap raw episode length before action repeat is applied by the wrapper.
    env = TimeLimit(env, TASKS[task]["episode_length"])
    env = SimXarmWrapper(env, TASKS[task], obs_mode, image_size, action_repeat, frame_stack, channel_last)
    # NOTE(review): gymnasium removed `Env.seed`; this relies on attribute
    # forwarding to a `seed` method on the underlying xarm task env — confirm
    # against xarm.tasks.base.Base.
    env.seed(seed)

    return env
|