WIP Aloha env tests pass

Cadene 2024-03-08 14:37:23 +00:00
parent d98b435b4c
commit ebbcad8c05
4 changed files with 234 additions and 126 deletions

View File

@@ -1,8 +1,57 @@
from pathlib import Path

### Simulation envs fixed constants
DT = 0.02  # 0.02 s per control step -> 1 / 0.02 = 50 Hz
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
FPS = 50
JOINTS = [
# absolute joint position
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"left_arm_gripper",
# absolute joint position
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"right_arm_gripper",
]
# TODO(rcadene): this is for end-to-end control, not for when we control the end effector
# TODO(rcadene): dimension names are wrong
ACTIONS = [
# position and quaternion for end effector
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
"left_arm_left_finger",
# normalized gripper position (0: close, 1: open)
"left_arm_right_finger",
# position and quaternion for end effector
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
"right_arm_left_finger",
# normalized gripper position (0: close, 1: open)
"right_arm_right_finger",
]
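For orientation, here is a quick dimension check implied by the two lists above (an editor's sketch, not part of the commit; it assumes the module is importable as lerobot.common.envs.aloha.constants, consistent with the imports elsewhere in this diff):

# Editor's sketch: sanity-check the sizes implied by JOINTS and ACTIONS.
from lerobot.common.envs.aloha.constants import ACTIONS, JOINTS

assert len(JOINTS) == 14   # (6 arm joints + 1 gripper) per arm, 2 arms -> qpos/"state" dim
assert len(ACTIONS) == 16  # (6 arm joints + 2 finger joints) per arm, 2 arms -> action dim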
START_ARM_POSE = [
    0,
    -0.96,
@@ -36,62 +85,84 @@ MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2

############################ Helper functions ############################
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
    MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
    PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
    + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
    + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
    MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)
)
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
    MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
    PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (
    MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (
    PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_POS2JOINT = (
    lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)
    * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
    + MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
    lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x)
    * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
    + PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2


def normalize_master_gripper_position(x):
    return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
        MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
    )


def normalize_puppet_gripper_position(x):
    return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
        PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
    )


def unnormalize_master_gripper_position(x):
    return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE


def unnormalize_puppet_gripper_position(x):
    return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE


def convert_position_from_master_to_puppet(x):
    return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))


def normalizer_master_gripper_joint(x):
    return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)


def normalize_puppet_gripper_joint(x):
    return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)


def unnormalize_master_gripper_joint(x):
    return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE


def unnormalize_puppet_gripper_joint(x):
    return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE


def convert_join_from_master_to_puppet(x):
    return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))


def normalize_master_gripper_velocity(x):
    return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)


def normalize_puppet_gripper_velocity(x):
    return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)


def convert_master_from_position_to_joint(x):
    return (
        normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
        + MASTER_GRIPPER_JOINT_CLOSE
    )


def convert_master_from_joint_to_position(x):
    return unnormalize_master_gripper_position(
        (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
    )


def convert_puppet_from_position_to_join(x):
    return (
        normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
        + PUPPET_GRIPPER_JOINT_CLOSE
    )


def convert_puppet_from_joint_to_position(x):
    return unnormalize_puppet_gripper_position(
        (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
    )
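A short usage sketch of the new helper functions (editor's illustration under the same assumed module path; each normalize/unnormalize pair is an inverse linear map, and the master-to-puppet conversion composes them):

# Editor's sketch, not part of the commit.
from lerobot.common.envs.aloha.constants import (
    MASTER_GRIPPER_POSITION_OPEN,
    PUPPET_GRIPPER_POSITION_CLOSE,
    PUPPET_GRIPPER_POSITION_OPEN,
    convert_position_from_master_to_puppet,
    normalize_puppet_gripper_position,
    unnormalize_puppet_gripper_position,
)

raw = 0.5 * (PUPPET_GRIPPER_POSITION_OPEN + PUPPET_GRIPPER_POSITION_CLOSE)
norm = normalize_puppet_gripper_position(raw)  # midpoint -> 0.5
assert abs(unnormalize_puppet_gripper_position(norm) - raw) < 1e-9  # round trip

# A fully open master gripper maps to a fully open puppet gripper.
converted = convert_position_from_master_to_puppet(MASTER_GRIPPER_POSITION_OPEN)
assert abs(converted - PUPPET_GRIPPER_POSITION_OPEN) < 1e-9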

View File

@@ -1,8 +1,10 @@
import collections
import importlib
import logging
from collections import deque
from typing import Optional

import einops
import numpy as np
import torch
from dm_control import mujoco
@@ -16,59 +18,60 @@ from torchrl.data.tensor_specs import (
    UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

from lerobot.common.utils import set_seed

from .constants import (
    ACTIONS,
    ASSETS_DIR,
    DT,
    JOINTS,
    PUPPET_GRIPPER_POSITION_CLOSE,
    PUPPET_GRIPPER_POSITION_NORMALIZE_FN,
    PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN,
    PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN,
    START_ARM_POSE,
    normalize_puppet_gripper_position,
    normalize_puppet_gripper_velocity,
    unnormalize_puppet_gripper_position,
)
from .utils import sample_box_pose, sample_insertion_pose

_has_gym = importlib.util.find_spec("gym") is not None
# def make_ee_sim_env(task_name):
#     """
#     Environment for simulated robot bi-manual manipulation, with end-effector control.
#     Action space:      [left_arm_pose (7),            # position and quaternion for end effector
#                         left_gripper_positions (1),   # normalized gripper position (0: close, 1: open)
#                         right_arm_pose (7),           # position and quaternion for end effector
#                         right_gripper_positions (1)]  # normalized gripper position (0: close, 1: open)
#     Observation space: {"qpos": Concat[ left_arm_qpos (6),          # absolute joint position
#                                         left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
#                                         right_arm_qpos (6),         # absolute joint position
#                                         right_gripper_qpos (1)]     # normalized gripper position (0: close, 1: open)
#                         "qvel": Concat[ left_arm_qvel (6),          # absolute joint velocity (rad)
#                                         left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
#                                         right_arm_qvel (6),         # absolute joint velocity (rad)
#                                         right_gripper_qvel (1)]     # normalized gripper velocity (pos: opening, neg: closing)
#                         "images": {"main": (480x640x3)}             # h, w, c, dtype='uint8'
#     """
#     if "sim_transfer_cube" in task_name:
#         xml_path = ASSETS_DIR / "bimanual_viperx_ee_transfer_cube.xml"
#         physics = mujoco.Physics.from_xml_path(xml_path)
#         task = TransferCubeEETask(random=False)
#         env = control.Environment(
#             physics, task, time_limit=20, control_timestep=DT, n_sub_steps=None, flat_observation=False
#         )
#     elif "sim_insertion" in task_name:
#         xml_path = ASSETS_DIR / "bimanual_viperx_ee_insertion.xml"
#         physics = mujoco.Physics.from_xml_path(xml_path)
#         task = InsertionEETask(random=False)
#         env = control.Environment(
#             physics, task, time_limit=20, control_timestep=DT, n_sub_steps=None, flat_observation=False
#         )
#     else:
#         raise NotImplementedError
#     return env
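Although make_ee_sim_env is commented out here, the same dm_control construction pattern is reused inside AlohaEnv further down; a minimal sketch of it (editor's illustration; the import path for TransferCubeEETask is an assumption, since file paths are not shown in this diff):

# Editor's sketch, not part of the commit.
from dm_control import mujoco
from dm_control.rl import control

from lerobot.common.envs.aloha.constants import ASSETS_DIR, DT
from lerobot.common.envs.aloha.env import TransferCubeEETask  # assumed import path

physics = mujoco.Physics.from_xml_path(str(ASSETS_DIR / "bimanual_viperx_ee_transfer_cube.xml"))
task = TransferCubeEETask(random=False)
env = control.Environment(
    physics, task, time_limit=float("inf"), control_timestep=DT, n_sub_steps=None, flat_observation=False
)
ts = env.reset()
print(ts.observation["qpos"].shape)  # (14,), i.e. len(JOINTS)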
class BimanualViperXEETask(base.Task):
@@ -89,8 +92,8 @@ class BimanualViperXEETask(base.Task):
        np.copyto(physics.data.mocap_quat[1], action_right[3:7])

        # set gripper
        g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
        g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
        np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))

    def initialize_robots(self, physics):
@@ -131,8 +134,8 @@ class BimanualViperXEETask(base.Task):
        right_qpos_raw = qpos_raw[8:16]
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
        right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    @staticmethod
@@ -142,8 +145,8 @@ class BimanualViperXEETask(base.Task):
        right_qvel_raw = qvel_raw[8:16]
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
        right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    @staticmethod
@@ -156,7 +159,7 @@ class BimanualViperXEETask(base.Task):
        obs["qpos"] = self.get_qpos(physics)
        obs["qvel"] = self.get_qvel(physics)
        obs["env_state"] = self.get_env_state(physics)
        obs["images"] = {}
        obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
        obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
        obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
@@ -234,7 +237,9 @@ class InsertionEETask(BimanualViperXEETask):
        self.initialize_robots(physics)
        # randomize peg and socket position
        peg_pose, socket_pose = sample_insertion_pose()
        id2index = lambda j_id: 16 + (j_id - 16) * 7  # first 16 is robot qpos, 7 is pose dim # hacky

        def id2index(j_id):
            return 16 + (j_id - 16) * 7  # first 16 is robot qpos, 7 is pose dim # hacky

        peg_start_id = physics.model.name2id("red_peg_joint", "joint")
        peg_start_idx = id2index(peg_start_id)
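id2index maps a MuJoCo free-joint id to its offset in qpos: the first 16 entries are the two arms' joint positions, and every free joint after that occupies a 7-dim pose (3-D position plus a unit quaternion). A tiny check of the arithmetic (editor's sketch with illustrative ids):

# Editor's sketch, not part of the commit.
def id2index(j_id):
    return 16 + (j_id - 16) * 7

assert id2index(16) == 16  # first object pose starts right after the 16 robot qpos entries
assert id2index(17) == 23  # the next object pose starts 7 entries (xyz + quaternion) later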
@@ -333,19 +338,22 @@ class AlohaEnv(EnvBase):
        if not from_pixels:
            raise NotImplementedError()

        # time limit is controlled by StepCounter in factory
        time_limit = float("inf")

        if "sim_transfer_cube" in task:
            xml_path = ASSETS_DIR / "bimanual_viperx_ee_transfer_cube.xml"
            physics = mujoco.Physics.from_xml_path(str(xml_path))
            task = TransferCubeEETask(random=False)
            env = control.Environment(
                physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
            )
        elif "sim_insertion" in task:
            xml_path = ASSETS_DIR / "bimanual_viperx_ee_insertion.xml"
            physics = mujoco.Physics.from_xml_path(str(xml_path))
            task = InsertionEETask(random=False)
            env = control.Environment(
                physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
            )
        else:
            raise NotImplementedError
@@ -373,14 +381,16 @@ class AlohaEnv(EnvBase):
    def _format_raw_obs(self, raw_obs):
        if self.from_pixels:
            image = torch.from_numpy(raw_obs["images"]["top"].copy())
            image = einops.rearrange(image, "h w c -> c h w")
            obs = {"image": image.type(torch.float32) / 255.0}

            if not self.pixels_only:
                obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
        else:
            # TODO(rcadene):
            raise NotImplementedError()
            # obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}

        return obs
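With from_pixels=True and pixels_only=False, _format_raw_obs converts the dm_control observation dict into CHW float tensors; a self-contained sketch of the expected shapes (editor's illustration using dummy arrays):

# Editor's sketch, not part of the commit.
import einops
import numpy as np
import torch

raw_obs = {
    "images": {"top": np.zeros((480, 640, 3), dtype=np.uint8)},
    "qpos": np.zeros(14, dtype=np.float64),
}
image = torch.from_numpy(raw_obs["images"]["top"].copy())
image = einops.rearrange(image, "h w c -> c h w")
obs = {
    "image": image.type(torch.float32) / 255.0,
    "state": torch.from_numpy(raw_obs["qpos"]).type(torch.float32),
}
assert obs["image"].shape == (3, 480, 640)  # float32 values in [0, 1]
assert obs["state"].shape == (14,)  # len(JOINTS)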
@@ -391,9 +401,10 @@ class AlohaEnv(EnvBase):
            self._current_seed += 1
            self.set_seed(self._current_seed)
            raw_obs = self._env.reset()
            # TODO(rcadene): add assert
            # assert self._current_seed == self._env._seed

            obs = self._format_raw_obs(raw_obs.observation)

            if self.num_prev_obs > 0:
                stacked_obs = {}
@@ -435,9 +446,12 @@ class AlohaEnv(EnvBase):
        num_action_steps = action.shape[0]
        for i in range(num_action_steps):
            _, reward, discount, raw_obs = self._env.step(action[i])
            del discount  # not used

            # TODO(rcadene): add an enum
            success = done = reward == 4
            sum_reward += reward
            obs = self._format_raw_obs(raw_obs)

            if self.num_prev_obs > 0:
@@ -456,7 +470,7 @@ class AlohaEnv(EnvBase):
                "reward": torch.tensor([sum_reward], dtype=torch.float32),
                # success and done are true when the maximum reward (4) is reached
                "done": torch.tensor([done], dtype=torch.bool),
                "success": torch.tensor([success], dtype=torch.bool),
            },
            batch_size=[],
        )
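The reward == 4 check relies on the staged rewards of the ALOHA sim tasks, where 4 is the maximum value returned once the task is fully solved; one way to address the "add an enum" TODO could look like this (editor's sketch, the enum name and stage labels are hypothetical):

# Editor's sketch, not part of the commit.
import enum


class AlohaTaskReward(enum.IntEnum):  # hypothetical name
    TOUCHED = 1
    LIFTED = 2
    TRANSFERRED_OR_ALIGNED = 3
    SUCCESS = 4


reward = 4  # e.g. what the transfer-cube task returns once the cube is handed over and lifted
success = done = reward == AlohaTaskReward.SUCCESS
assert success and done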
@@ -464,8 +478,17 @@ class AlohaEnv(EnvBase):
    def _make_spec(self):
        obs = {}
        from omegaconf import OmegaConf

        if self.from_pixels:
            if isinstance(self.image_size, int):
                image_shape = (3, self.image_size, self.image_size)
            elif OmegaConf.is_list(self.image_size):
                assert len(self.image_size) == 3  # c h w
                assert self.image_size[0] == 3  # c is RGB
                image_shape = tuple(self.image_size)
            else:
                raise ValueError(self.image_size)

            if self.num_prev_obs > 0:
                image_shape = (self.num_prev_obs + 1, *image_shape)
@@ -477,33 +500,44 @@ class AlohaEnv(EnvBase):
                device=self.device,
            )
            if not self.pixels_only:
                state_shape = (len(JOINTS),)
                if self.num_prev_obs > 0:
                    state_shape = (self.num_prev_obs + 1, *state_shape)

                obs["state"] = UnboundedContinuousTensorSpec(
                    # TODO: add low and high bounds
                    shape=state_shape,
                    dtype=torch.float32,
                    device=self.device,
                )
        else:
            # TODO(rcadene): add observation_space achieved_goal and desired_goal?
            state_shape = (len(JOINTS),)
            if self.num_prev_obs > 0:
                state_shape = (self.num_prev_obs + 1, *state_shape)

            obs["state"] = UnboundedContinuousTensorSpec(
                # TODO: add low and high bounds
                shape=state_shape,
                dtype=torch.float32,
                device=self.device,
            )
        self.observation_spec = CompositeSpec({"observation": obs})

        # TODO(rcadene): valid when controlling end effector?
        # action_space = self._env.action_spec()
        # self.action_spec = BoundedTensorSpec(
        #     low=action_space.minimum,
        #     high=action_space.maximum,
        #     shape=action_space.shape,
        #     dtype=torch.float32,
        #     device=self.device,
        # )

        # TODO(rcadene): add bounds (where are they?)
        self.action_spec = UnboundedContinuousTensorSpec(
            shape=(len(ACTIONS)),
            dtype=torch.float32,
            device=self.device,
        )
@@ -532,4 +566,6 @@ class AlohaEnv(EnvBase):
    def _set_seed(self, seed: Optional[int]):
        set_seed(seed)
        # TODO(rcadene): seed the env
        # self._env.seed(seed)
        logging.warning("Aloha env is not seeded")

View File

@@ -26,6 +26,7 @@ def make_env(cfg, transform=None):
    elif cfg.env.name == "aloha":
        from lerobot.common.envs.aloha.env import AlohaEnv

        kwargs["task"] = cfg.env.task
        clsfunc = AlohaEnv
    else:
        raise ValueError(cfg.env.name)
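With this change the aloha branch forwards cfg.env.task to the environment constructor; a minimal sketch of how the config value flows through make_env (editor's illustration; only the fields relevant to this branch are shown):

# Editor's sketch, not part of the commit.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"env": {"name": "aloha", "task": "sim_insertion_human"}})

kwargs = {}
if cfg.env.name == "aloha":
    kwargs["task"] = cfg.env.task  # the line added by this commit
print(kwargs)  # {'task': 'sim_insertion_human'}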

View File

@@ -15,7 +15,7 @@ env:
  task: sim_insertion_human
  from_pixels: True
  pixels_only: False
  image_size: 96
  image_size: [3, 480, 640]
  action_repeat: 1
  episode_length: 300
  fps: ${fps}
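image_size is now a [c, h, w] list rather than a single int, matching the checks added to AlohaEnv._make_spec; a quick validation of the new value (editor's sketch):

# Editor's sketch, not part of the commit: the config value must pass the
# assertions added in AlohaEnv._make_spec (a 3-element [c, h, w] list with c == 3).
image_size = [3, 480, 640]
assert len(image_size) == 3 and image_size[0] == 3
image_shape = tuple(image_size)
print(image_shape)  # (3, 480, 640)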