Add maniskill support.
Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>

parent 291358d6a2
commit b7a0ffc3b8
@@ -5,11 +5,16 @@ fps: 20
 env:
   name: maniskill/pushcube
   task: PushCube-v1
-  image_size: 64
+  image_size: 128
   control_mode: pd_ee_delta_pose
   state_dim: 25
   action_dim: 7
   fps: ${fps}
   obs: rgb
   render_mode: rgb_array
-  render_size: 64
+  render_size: 128
+  device: cuda
+
+  reward_classifier:
+    pretrained_path: null
+    config_path: null
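The env block above is consumed directly when the ManiSkill environment is constructed (see the make_robot_env and make_maniskill hunks further down). A minimal sketch of that mapping, assuming this block lives in lerobot/configs/env/maniskill_example.yaml and is loaded with OmegaConf:

# Sketch only: key names mirror the env block above; loading with OmegaConf.load is an assumption.
from omegaconf import OmegaConf
from lerobot.scripts.server.maniskill_manipulator import make_maniskill

cfg = OmegaConf.load("lerobot/configs/env/maniskill_example.yaml")
env = make_maniskill(
    task=cfg.env.task,                  # "PushCube-v1"
    obs_mode=cfg.env.obs,               # "rgb"
    control_mode=cfg.env.control_mode,  # "pd_ee_delta_pose"
    render_mode=cfg.env.render_mode,    # "rgb_array"
    sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size},  # 128x128 after this change
    device=cfg.env.device,              # "cuda"
)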
@@ -8,7 +8,7 @@
 # env.gym.obs_type=environment_state_agent_pos \
 
 seed: 1
-dataset_repo_id: aractingi/hil-serl-maniskill-pushcube
+dataset_repo_id: null
 
 training:
   # Offline training dataloader
@@ -20,7 +20,7 @@ training:
   lr: 3e-4
 
   eval_freq: 2500
-  log_freq: 500
+  log_freq: 10
   save_freq: 2000000
 
   online_steps: 1000000
@@ -52,14 +52,16 @@ policy:
   n_action_steps: 1
 
   shared_encoder: true
-  # vision_encoder_name: null
+  vision_encoder_name: null
+  # vision_encoder_name: "helper2424/resnet10"
+  # freeze_vision_encoder: true
   freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
-    observation.image: [3, 64, 64]
+    observation.image: [3, 128, 128]
   output_shapes:
-    action: ["${env.action_dim}"]
+    action: [7]
 
   # Normalization / Unnormalization
   input_normalization_modes: null
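These shape changes have to stay in sync with the env block earlier in the diff: observation.image follows image_size/render_size (now 128), and the hard-coded action: [7] matches action_dim: 7. A hedged sanity check using a hypothetical helper that is not part of the repo:

# Hypothetical check_policy_shapes() helper, shown only to make the config coupling explicit.
def check_policy_shapes(cfg):
    c, h, w = cfg.policy.input_shapes["observation.image"]
    assert (h, w) == (cfg.env.image_size, cfg.env.image_size)               # 128 x 128 after this change
    assert list(cfg.policy.output_shapes["action"]) == [cfg.env.action_dim]  # [7]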
@@ -67,8 +69,8 @@ policy:
     action: min_max
   output_normalization_params:
     action:
-      min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
-      max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+      min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
+      max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
 
   # Architecture / modeling.
   # Neural networks.
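The wider ±10.0 bounds pair with the ManiSkillMultiplyActionWrapper added in the new file at the end of this diff: the wrapper scales the advertised action space by multiply_factor=10 and divides incoming actions back down before stepping, so the underlying controller still receives values in roughly [-1, 1]. A small sketch of that round trip:

# Sketch of the action round trip implied by the new bounds (assumes multiply_factor=10, the wrapper default).
import numpy as np

policy_action = np.array([10.0, -10.0, 2.5, 0.0, 0.0, 0.0, 0.0])  # unnormalized action at the new bounds
env_action = policy_action / 10.0  # what ManiSkillMultiplyActionWrapper.step() forwards to the env
assert np.all(np.abs(env_action) <= 1.0)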
@@ -88,14 +90,3 @@ policy:
   actor_learner_config:
     actor_ip: "127.0.0.1"
     port: 50051
-
-  # # Loss coefficients.
-  # reward_coeff: 0.5
-  # expectile_weight: 0.9
-  # value_coeff: 0.1
-  # consistency_coeff: 20.0
-  # advantage_scaling: 3.0
-  # pi_coeff: 0.5
-  # temporal_decay_coeff: 0.5
-  # # Target model.
-  # target_model_momentum: 0.995
@@ -251,7 +251,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
             sum_reward_episode += float(reward)
 
             # NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
-            if info["is_intervention"]:
+            if "is_intervention" in info and info["is_intervention"]:
                 # TODO: Check the shape
                 # NOTE: The action space for demonstration before hand is with the full action space
                 # but sometimes for example we want to deactivate the gripper
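The added membership check matters because the ManiSkill wrappers introduced in this diff never put an "is_intervention" key into the step info, unlike the HIL robot wrappers. A minimal illustration:

# Minimal illustration: dict membership guards against envs whose step() info lacks "is_intervention".
info = {}  # e.g. what a wrapped ManiSkill step may return
if "is_intervention" in info and info["is_intervention"]:
    pass  # only reached for HIL robot envs that actually set the key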
@@ -348,6 +348,14 @@ def actor_cli(cfg: dict):
     robot = make_robot(cfg=cfg.robot)
 
     server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
+
+    # HACK: FOR MANISKILL we do not have a reward classifier
+    # TODO: Remove this once we merge into main
+    reward_classifier = None
+    if (
+        cfg.env.reward_classifier.pretrained_path is not None
+        and cfg.env.reward_classifier.config_path is not None
+    ):
         reward_classifier = get_classifier(
             pretrained_path=cfg.env.reward_classifier.pretrained_path,
             config_path=cfg.env.reward_classifier.config_path,
@@ -13,6 +13,8 @@ from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.utils.utils import init_hydra_config, log_say
+from lerobot.scripts.server.maniskill_manipulator import make_maniskill
+
 
 logging.basicConfig(level=logging.INFO)
 
@@ -661,6 +663,9 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
         return observation
 
 
+# TODO: REMOVE TH
+
+
 def make_robot_env(
     robot,
     reward_classifier,
@@ -679,7 +684,17 @@ def make_robot_env(
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
+    if "maniskill" in cfg.name:
+        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
+        env = make_maniskill(
+            task=cfg.task,
+            obs_mode=cfg.obs,
+            control_mode=cfg.control_mode,
+            render_mode=cfg.render_mode,
+            sensor_configs={"width": cfg.render_size, "height": cfg.render_size},
+            device=cfg.device,
+        )
+        return env
     # Create base environment
     env = HILSerlRobotEnv(
         robot=robot,
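With a maniskill config, make_robot_env short-circuits before any robot wrappers are created, so neither the robot nor the reward classifier is used in that branch. A hedged usage sketch, assuming the remaining parameter of make_robot_env is the env config:

# Sketch only: assumes make_robot_env(robot, reward_classifier, cfg) where cfg is the env config.
env = make_robot_env(robot=None, reward_classifier=None, cfg=cfg.env)  # cfg.env.name == "maniskill/pushcube"
obs, info = env.reset()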
@@ -362,7 +362,7 @@ def add_actor_information_and_train(
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
             interaction_message["Interaction step"] += interaction_step_shift
             logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
-            logging.info(f"Interaction message: {interaction_message}")
+            # logging.info(f"Interaction message: {interaction_message}")
 
         if len(replay_buffer) < cfg.training.online_step_before_learning:
             continue
@@ -0,0 +1,176 @@
+import einops
+import numpy as np
+import gymnasium as gym
+import torch
+
+"""Make ManiSkill3 gym environment"""
+from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
+
+
+def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
+    """Convert environment observation to LeRobot format observation.
+    Args:
+        observation: Dictionary of observation batches from a Gym vector environment.
+    Returns:
+        Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
+    """
+    # map to expected inputs for the policy
+    return_observations = {}
+    # TODO: You have to merge all tensors from agent key and extra key
+    # You don't keep sensor param key in the observation
+    # And you keep sensor data rgb
+    q_pos = observations["agent"]["qpos"]
+    q_vel = observations["agent"]["qvel"]
+    tcp_pos = observations["extra"]["tcp_pose"]
+    img = observations["sensor_data"]["base_camera"]["rgb"]
+
+    _, h, w, c = img.shape
+    assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
+
+    # sanity check that images are uint8
+    assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
+
+    # convert to channel first of type float32 in range [0,1]
+    img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
+    img = img.type(torch.float32)
+    img /= 255
+
+    state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
+
+    return_observations["observation.image"] = img
+    return_observations["observation.state"] = state
+    return return_observations
+
+
+class ManiSkillObservationWrapper(gym.ObservationWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+    def observation(self, observation):
+        return preprocess_maniskill_observation(observation)
+
+
+class ManiSkillToDeviceWrapper(gym.Wrapper):
+    def __init__(self, env, device: torch.device = "cuda"):
+        super().__init__(env)
+        self.device = device
+
+    def reset(self, seed=None, options=None):
+        obs, info = self.env.reset(seed=seed, options=options)
+        obs = {k: v.to(self.device) for k, v in obs.items()}
+        return obs, info
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        obs = {k: v.to(self.device) for k, v in obs.items()}
+        return obs, reward, terminated, truncated, info
+
+
+class ManiSkillCompat(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        reward = reward.item()
+        terminated = terminated.item()
+        truncated = truncated.item()
+        return obs, reward, terminated, truncated, info
+
+
+class ManiSkillActionWrapper(gym.ActionWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
+
+    def action(self, action):
+        action, telop = action
+        return action
+
+
+class ManiSkillMultiplyActionWrapper(gym.Wrapper):
+    def __init__(self, env, multiply_factor: float = 10):
+        super().__init__(env)
+        self.multiply_factor = multiply_factor
+        action_space_agent: gym.spaces.Box = env.action_space[0]
+        action_space_agent.low = action_space_agent.low * multiply_factor
+        action_space_agent.high = action_space_agent.high * multiply_factor
+        self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
+
+    def step(self, action):
+        if isinstance(action, tuple):
+            action, telop = action
+        else:
+            telop = 0
+        action = action / self.multiply_factor
+        obs, reward, terminated, truncated, info = self.env.step((action, telop))
+        return obs, reward, terminated, truncated, info
+
+
+def make_maniskill(
+    task: str = "PushCube-v1",
+    obs_mode: str = "rgb",
+    control_mode: str = "pd_ee_delta_pose",
+    render_mode: str = "rgb_array",
+    sensor_configs: dict[str, int] | None = None,
+    n_envs: int = 1,
+    device: torch.device = "cuda",
+) -> gym.Env:
+    """
+    Factory function to create a ManiSkill environment with standard wrappers.
+
+    Args:
+        task: Name of the ManiSkill task
+        obs_mode: Observation mode (rgb, rgbd, etc)
+        control_mode: Control mode for the robot
+        render_mode: Rendering mode
+        sensor_configs: Camera sensor configurations
+        n_envs: Number of parallel environments
+
+    Returns:
+        A wrapped ManiSkill environment
+    """
+    if sensor_configs is None:
+        sensor_configs = {"width": 64, "height": 64}
+
+    env = gym.make(
+        task,
+        obs_mode=obs_mode,
+        control_mode=control_mode,
+        render_mode=render_mode,
+        sensor_configs=sensor_configs,
+        num_envs=n_envs,
+    )
+    env = ManiSkillCompat(env)
+    env = ManiSkillObservationWrapper(env)
+    env = ManiSkillActionWrapper(env)
+    env = ManiSkillMultiplyActionWrapper(env)
+    env = ManiSkillToDeviceWrapper(env, device=device)
+    return env
+
+
+if __name__ == "__main__":
+    import argparse
+    import hydra
+    from omegaconf import OmegaConf
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
+    args = parser.parse_args()
+
+    # Initialize config
+    with hydra.initialize(version_base=None, config_path="../../configs"):
+        cfg = hydra.compose(config_name="env/maniskill_example.yaml")
+
+    env = make_maniskill(
+        task=cfg.env.task,
+        obs_mode=cfg.env.obs,
+        control_mode=cfg.env.control_mode,
+        render_mode=cfg.env.render_mode,
+        sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size},
+    )
+
+    print("env done")
+    obs, info = env.reset()
+    random_action = env.action_space.sample()
+    obs, reward, terminated, truncated, info = env.step(random_action)
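A hedged shape check for the preprocessed observation returned by this new module, assuming the default Panda + PushCube-v1 setup where qpos (9) + qvel (9) + tcp_pose (7) concatenate to the state_dim of 25 used in the env config, and the camera renders at the configured 128x128:

# Sketch only: the 9 + 9 + 7 = 25 state breakdown is an assumption about the default Panda + PushCube-v1 setup.
import torch

env = make_maniskill(task="PushCube-v1", sensor_configs={"width": 128, "height": 128})
obs, info = env.reset()
assert obs["observation.state"].shape[-1] == 25
assert obs["observation.image"].shape[-3:] == (3, 128, 128)
assert obs["observation.image"].dtype == torch.float32  # rescaled to [0, 1] by preprocess_maniskill_observation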