diff --git a/lerobot/configs/env/maniskill_example.yaml b/lerobot/configs/env/maniskill_example.yaml
index cedf7a30..03814614 100644
--- a/lerobot/configs/env/maniskill_example.yaml
+++ b/lerobot/configs/env/maniskill_example.yaml
@@ -5,11 +5,16 @@ fps: 20
 env:
   name: maniskill/pushcube
   task: PushCube-v1
-  image_size: 64
+  image_size: 128
   control_mode: pd_ee_delta_pose
   state_dim: 25
   action_dim: 7
   fps: ${fps}
   obs: rgb
   render_mode: rgb_array
-  render_size: 64
\ No newline at end of file
+  render_size: 128
+  device: cuda
+
+  reward_classifier:
+    pretrained_path: null
+    config_path: null
\ No newline at end of file
diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml
index aaf59e53..8a36947c 100644
--- a/lerobot/configs/policy/sac_maniskill.yaml
+++ b/lerobot/configs/policy/sac_maniskill.yaml
@@ -8,7 +8,7 @@
 # env.gym.obs_type=environment_state_agent_pos \
 
 seed: 1
-dataset_repo_id: aractingi/hil-serl-maniskill-pushcube
+dataset_repo_id: null
 
 training:
   # Offline training dataloader
@@ -20,7 +20,7 @@ training:
   lr: 3e-4
 
   eval_freq: 2500
-  log_freq: 500
+  log_freq: 10
   save_freq: 2000000
 
   online_steps: 1000000
@@ -52,14 +52,16 @@ policy:
   n_action_steps: 1
 
   shared_encoder: true
-  # vision_encoder_name: null
+  vision_encoder_name: null
+  # vision_encoder_name: "helper2424/resnet10"
+  # freeze_vision_encoder: true
   freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
-    observation.image: [3, 64, 64]
+    observation.image: [3, 128, 128]
   output_shapes:
-    action: ["${env.action_dim}"]
+    action: [7]
 
   # Normalization / Unnormalization
   input_normalization_modes: null
@@ -67,8 +69,8 @@ policy:
     action: min_max
   output_normalization_params:
     action:
-      min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
-      max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+      min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
+      max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
 
   # Architecture / modeling.
   # Neural networks.
@@ -88,14 +90,3 @@ policy:
   actor_learner_config:
     actor_ip: "127.0.0.1"
     port: 50051
-
-  # # Loss coefficients.
-  # reward_coeff: 0.5
-  # expectile_weight: 0.9
-  # value_coeff: 0.1
-  # consistency_coeff: 20.0
-  # advantage_scaling: 3.0
-  # pi_coeff: 0.5
-  # temporal_decay_coeff: 0.5
-  # # Target model.
-  # target_model_momentum: 0.995
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index 8284f024..b5a6183d 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -251,7 +251,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
         sum_reward_episode += float(reward)
 
         # NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
-        if info["is_intervention"]:
+        if "is_intervention" in info and info["is_intervention"]:
             # TODO: Check the shape
             # NOTE: The action space for demonstration before hand is with the full action space
             # but sometimes for example we want to deactivate the gripper
@@ -348,10 +348,18 @@ def actor_cli(cfg: dict):
     robot = make_robot(cfg=cfg.robot)
 
     server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
-    reward_classifier = get_classifier(
-        pretrained_path=cfg.env.reward_classifier.pretrained_path,
-        config_path=cfg.env.reward_classifier.config_path,
-    )
+
+    # HACK: FOR MANISKILL we do not have a reward classifier
+    # TODO: Remove this once we merge into main
+    reward_classifier = None
+    if (
+        cfg.env.reward_classifier.pretrained_path is not None
+        and cfg.env.reward_classifier.config_path is not None
+    ):
+        reward_classifier = get_classifier(
+            pretrained_path=cfg.env.reward_classifier.pretrained_path,
+            config_path=cfg.env.reward_classifier.config_path,
+        )
     policy_thread = Thread(
         target=act_with_policy,
         daemon=True,
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index b3d71d4d..d0fabc8e 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -13,6 +13,8 @@ from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.utils.utils import init_hydra_config, log_say
+from lerobot.scripts.server.maniskill_manipulator import make_maniskill
+
 
 logging.basicConfig(level=logging.INFO)
 
@@ -661,6 +663,9 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
         return observation
 
 
+# TODO: REMOVE THIS
+
+
 def make_robot_env(
     robot,
     reward_classifier,
@@ -679,7 +684,17 @@
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
""" - + if "maniskill" in cfg.name: + logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN") + env = make_maniskill( + task=cfg.task, + obs_mode=cfg.obs, + control_mode=cfg.control_mode, + render_mode=cfg.render_mode, + sensor_configs={"width": cfg.render_size, "height": cfg.render_size}, + device=cfg.device, + ) + return env # Create base environment env = HILSerlRobotEnv( robot=robot, diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 2d8eab67..78b5d7b8 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -362,7 +362,7 @@ def add_actor_information_and_train( # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging interaction_message["Interaction step"] += interaction_step_shift logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step") - logging.info(f"Interaction message: {interaction_message}") + # logging.info(f"Interaction message: {interaction_message}") if len(replay_buffer) < cfg.training.online_step_before_learning: continue diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py new file mode 100644 index 00000000..8544d157 --- /dev/null +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -0,0 +1,176 @@ +import einops +import numpy as np +import gymnasium as gym +import torch + +"""Make ManiSkill3 gym environment""" +from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv + + +def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]: + """Convert environment observation to LeRobot format observation. + Args: + observation: Dictionary of observation batches from a Gym vector environment. + Returns: + Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. 
+ """ + # map to expected inputs for the policy + return_observations = {} + # TODO: You have to merge all tensors from agent key and extra key + # You don't keep sensor param key in the observation + # And you keep sensor data rgb + q_pos = observations["agent"]["qpos"] + q_vel = observations["agent"]["qvel"] + tcp_pos = observations["extra"]["tcp_pose"] + img = observations["sensor_data"]["base_camera"]["rgb"] + + _, h, w, c = img.shape + assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" + + # sanity check that images are uint8 + assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" + + # convert to channel first of type float32 in range [0,1] + img = einops.rearrange(img, "b h w c -> b c h w").contiguous() + img = img.type(torch.float32) + img /= 255 + + state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1) + + return_observations["observation.image"] = img + return_observations["observation.state"] = state + return return_observations + + +class ManiSkillObservationWrapper(gym.ObservationWrapper): + def __init__(self, env): + super().__init__(env) + + def observation(self, observation): + return preprocess_maniskill_observation(observation) + + +class ManiSkillToDeviceWrapper(gym.Wrapper): + def __init__(self, env, device: torch.device = "cuda"): + super().__init__(env) + self.device = device + + def reset(self, seed=None, options=None): + obs, info = self.env.reset(seed=seed, options=options) + obs = {k: v.to(self.device) for k, v in obs.items()} + return obs, info + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + obs = {k: v.to(self.device) for k, v in obs.items()} + return obs, reward, terminated, truncated, info + + +class ManiSkillCompat(gym.Wrapper): + def __init__(self, env): + super().__init__(env) + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + reward = reward.item() + terminated = terminated.item() + truncated = truncated.item() + return obs, reward, terminated, truncated, info + + +class ManiSkillActionWrapper(gym.ActionWrapper): + def __init__(self, env): + super().__init__(env) + self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2))) + + def action(self, action): + action, telop = action + return action + + +class ManiSkillMultiplyActionWrapper(gym.Wrapper): + def __init__(self, env, multiply_factor: float = 10): + super().__init__(env) + self.multiply_factor = multiply_factor + action_space_agent: gym.spaces.Box = env.action_space[0] + action_space_agent.low = action_space_agent.low * multiply_factor + action_space_agent.high = action_space_agent.high * multiply_factor + self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2))) + + def step(self, action): + if isinstance(action, tuple): + action, telop = action + else: + telop = 0 + action = action / self.multiply_factor + obs, reward, terminated, truncated, info = self.env.step((action, telop)) + return obs, reward, terminated, truncated, info + + +def make_maniskill( + task: str = "PushCube-v1", + obs_mode: str = "rgb", + control_mode: str = "pd_ee_delta_pose", + render_mode: str = "rgb_array", + sensor_configs: dict[str, int] | None = None, + n_envs: int = 1, + device: torch.device = "cuda", +) -> gym.Env: + """ + Factory function to create a ManiSkill environment with standard wrappers. 
+
+    Args:
+        task: Name of the ManiSkill task
+        obs_mode: Observation mode (rgb, rgbd, etc)
+        control_mode: Control mode for the robot
+        render_mode: Rendering mode
+        sensor_configs: Camera sensor configurations
+        n_envs: Number of parallel environments
+
+    Returns:
+        A wrapped ManiSkill environment
+    """
+    if sensor_configs is None:
+        sensor_configs = {"width": 64, "height": 64}
+
+    env = gym.make(
+        task,
+        obs_mode=obs_mode,
+        control_mode=control_mode,
+        render_mode=render_mode,
+        sensor_configs=sensor_configs,
+        num_envs=n_envs,
+    )
+    env = ManiSkillCompat(env)
+    env = ManiSkillObservationWrapper(env)
+    env = ManiSkillActionWrapper(env)
+    env = ManiSkillMultiplyActionWrapper(env)
+    env = ManiSkillToDeviceWrapper(env, device=device)
+    return env
+
+
+if __name__ == "__main__":
+    import argparse
+    import hydra
+    from omegaconf import OmegaConf
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
+    args = parser.parse_args()
+
+    # Initialize config
+    with hydra.initialize(version_base=None, config_path="../../configs"):
+        cfg = hydra.compose(config_name="env/maniskill_example.yaml")
+
+    env = make_maniskill(
+        task=cfg.env.task,
+        obs_mode=cfg.env.obs,
+        control_mode=cfg.env.control_mode,
+        render_mode=cfg.env.render_mode,
+        sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size},
+    )
+
+    print("env done")
+    obs, info = env.reset()
+    random_action = env.action_space.sample()
+    obs, reward, terminated, truncated, info = env.step(random_action)
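
A quick way to exercise the new ManiSkill path end to end is the smoke test below. This is a minimal sketch and not part of the diff: it assumes ManiSkill3 is installed, that a CUDA device is available, and that the defaults from maniskill_example.yaml and sac_maniskill.yaml above are in use (128x128 rendering, a 25-dim state from qpos + qvel + tcp_pose, 7-dim pd_ee_delta_pose actions).

# Smoke test for the ManiSkill wrappers added in this diff.
# Assumptions: ManiSkill3 is installed and a CUDA device is available.
import torch

from lerobot.scripts.server.maniskill_manipulator import make_maniskill

env = make_maniskill(
    task="PushCube-v1",
    obs_mode="rgb",
    control_mode="pd_ee_delta_pose",
    render_mode="rgb_array",
    sensor_configs={"width": 128, "height": 128},  # matches env.render_size / observation.image in the configs
    device="cuda",
)

obs, info = env.reset()
# ManiSkillObservationWrapper already returns LeRobot-style keys:
# a channel-first float image in [0, 1] and a state vector (qpos + qvel + tcp_pose = 25 dims for PushCube).
print(obs["observation.image"].shape, obs["observation.state"].shape)

# The policy acts in the enlarged [-10, 10] space exposed by ManiSkillMultiplyActionWrapper;
# the wrapper divides by multiply_factor=10 before the underlying controller sees the action,
# and the Discrete(2) element mirrors the teleoperation/intervention flag of the real-robot env.
action = (torch.zeros(1, 7), 0)
obs, reward, terminated, truncated, info = env.step(action)
print(reward, terminated, truncated)

The factor-of-10 rescaling in the wrapper appears to be why output_normalization_params in sac_maniskill.yaml moved from plus/minus 1.0 to plus/minus 10.0: the controller's true bounds stay intact while the SAC min_max normalization operates on the wider range.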