From 729b4ed697990fc6f84122c5cab6970879ee169b Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Thu, 6 Feb 2025 16:29:37 +0100 Subject: [PATCH] - Added `lerobot/scripts/server/gym_manipulator.py` that contains all the necessary wrappers to run a gym-style env around the real robot. - Added `lerobot/scripts/server/find_joint_limits.py` to test the min and max angles of the motion you wish the robot to explore during RL training. - Added logic in `manipulator.py` to limit the maximum possible joint angles to allow motion within a predefined joint position range. The limits are specified in the yaml config for each robot. Checkout the so100.yaml. Co-authored-by: Adil Zouitine --- lerobot/common/envs/factory.py | 12 +- lerobot/common/envs/utils.py | 25 +- .../hilserl/classifier/modeling_classifier.py | 15 +- .../robot_devices/robots/manipulator.py | 18 +- .../configs/policy/hilserl_classifier.yaml | 7 +- lerobot/configs/robot/so100.yaml | 3 + lerobot/scripts/server/find_joint_limits.py | 64 ++ lerobot/scripts/server/gym_manipulator.py | 697 ++++++++++++++++++ 8 files changed, 812 insertions(+), 29 deletions(-) create mode 100644 lerobot/scripts/server/find_joint_limits.py create mode 100644 lerobot/scripts/server/gym_manipulator.py diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 8aec915c..96ee7448 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -126,28 +126,30 @@ class PixelWrapper(gym.Wrapper): obs, reward, terminated, truncated, info = self.env.step(action) return self._get_obs(obs), reward, terminated, truncated, info + class ConvertToLeRobotEnv(gym.Wrapper): def __init__(self, env, num_envs): super().__init__(env) + def reset(self, seed=None, options=None): obs, info = self.env.reset(seed=seed, options={}) return self._get_obs(obs), info + def step(self, action): obs, reward, terminated, truncated, info = self.env.step(action) return self._get_obs(obs), reward, terminated, truncated, info + def _get_obs(self, observation): sensor_data = observation.pop("sensor_data") del observation["sensor_param"] images = [] for cam_data in sensor_data.values(): - images.append(cam_data["rgb"]) + images.append(cam_data["rgb"]) images = torch.concat(images, axis=-1) # flatten the rest of the data which should just be state data - observation = common.flatten_state_dict( - observation, use_torch=True, device=self.base_env.device - ) + observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device) ret = dict() ret["state"] = observation ret["pixels"] = images - return ret \ No newline at end of file + return ret diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index ead6bf45..b44e041b 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -36,11 +36,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten # TODO: You have to merge all tensors from agent key and extra key # You don't keep sensor param key in the observation # And you keep sensor data rgb - if "pixels" in observations: - if isinstance(observations["pixels"], dict): - imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} - else: - imgs = {"observation.image": observations["pixels"]} + for key, img in observations.items(): + if "images" not in key: + continue for imgkey, img in imgs.items(): # TODO(aliberts, rcadene): use transforms.ToTensor()? 
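The preprocess_observation hunk above switches from the old "pixels" entry to iterating over every observation key and converting any entry whose name contains "images" from channel-last uint8 to channel-first float32 in [0, 1]. A minimal, self-contained sketch of that conversion, with a made-up key name and random data standing in for a camera frame:

    import einops
    import torch

    def to_channel_first_float(observations: dict) -> dict:
        out = {}
        for key, img in observations.items():
            if "images" not in key:
                out[key] = img
                continue
            # expect batched channel-last uint8 images: (b, h, w, c)
            _, h, w, c = img.shape
            assert c < h and c < w and img.dtype == torch.uint8
            img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
            out[key] = img.to(torch.float32) / 255.0
        return out

    obs = {
        "observation.images.laptop": torch.randint(0, 256, (1, 480, 640, 3), dtype=torch.uint8),
        "observation.state": torch.zeros(1, 6),
    }
    processed = to_channel_first_float(obs)
    print(processed["observation.images.laptop"].shape)  # torch.Size([1, 3, 480, 640])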
@@ -50,15 +48,15 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten _, h, w, c = img.shape assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" - # sanity check that images are uint8 - assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" + # sanity check that images are uint8 + assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" - # convert to channel first of type float32 in range [0,1] - img = einops.rearrange(img, "b h w c -> b c h w").contiguous() - img = img.type(torch.float32) - img /= 255 + # convert to channel first of type float32 in range [0,1] + img = einops.rearrange(img, "b h w c -> b c h w").contiguous() + img = img.type(torch.float32) + img /= 255 - return_observations[imgkey] = img + return_observations[key] = img # obs state agent qpos and qvel # image @@ -69,7 +67,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing # requirement for "agent_pos" - return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() + # return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() + return_observations["observation.state"] = observations["observation.state"].float() return return_observations diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index 4a022335..58532302 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -47,7 +47,7 @@ class Classifier( super().__init__() self.config = config - self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) + # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) # Extract vision model if we're given a multimodal model if hasattr(encoder, "vision_model"): @@ -108,11 +108,12 @@ class Classifier( def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: """Extract the appropriate output from the encoder.""" # Process images with the processor (handles resizing and normalization) - processed = self.processor( - images=x, # LeRobotDataset already provides proper tensor format - return_tensors="pt", - ) - processed = processed["pixel_values"].to(x.device) + # processed = self.processor( + # images=x, # LeRobotDataset already provides proper tensor format + # return_tensors="pt", + # ) + # processed = processed["pixel_values"].to(x.device) + processed = x with torch.no_grad(): if self.is_cnn: @@ -146,6 +147,6 @@ class Classifier( def predict_reward(self, x): if self.config.num_classes == 2: - return (self.forward(x).probabilities > 0.5).float() + return (self.forward(x).probabilities > 0.6).float() else: return torch.argmax(self.forward(x).probabilities, dim=1) diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 9173abc6..05ced833 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -45,7 +45,7 @@ def ensure_safe_goal_position( safe_goal_pos = present_pos + safe_diff if not torch.allclose(goal_pos, safe_goal_pos): - 
logging.warning( + logging.debug( "Relative goal position magnitude had to be clamped to be safe.\n" f" requested relative goal position target: {diff}\n" f" clamped relative goal position target: {safe_diff}" @@ -464,6 +464,14 @@ class ManipulatorRobot: before_fwrite_t = time.perf_counter() goal_pos = leader_pos[name] + # If specified, clip the goal positions within predefined bounds specified in the config of the robot + if self.config.joint_position_relative_bounds is not None: + goal_pos = torch.clamp( + goal_pos, + self.config.joint_position_relative_bounds["min"], + self.config.joint_position_relative_bounds["max"], + ) + # Cap goal position when too far away from present position. # Slower fps expected due to reading from the follower. if self.config.max_relative_target is not None: @@ -585,6 +593,14 @@ class ManipulatorRobot: goal_pos = action[from_idx:to_idx] from_idx = to_idx + # If specified, clip the goal positions within predefined bounds specified in the config of the robot + if self.config.joint_position_relative_bounds is not None: + goal_pos = torch.clamp( + goal_pos, + self.config.joint_position_relative_bounds["min"], + self.config.joint_position_relative_bounds["max"], + ) + # Cap goal position when too far away from present position. # Slower fps expected due to reading from the follower. if self.config.max_relative_target is not None: diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml index f8137b69..21fd4a1a 100644 --- a/lerobot/configs/policy/hilserl_classifier.yaml +++ b/lerobot/configs/policy/hilserl_classifier.yaml @@ -4,7 +4,7 @@ defaults: - _self_ seed: 13 -dataset_repo_id: aractingi/pick_place_lego_cube_1 +dataset_repo_id: aractingi/push_green_cube_hf_cropped_resized train_split_proportion: 0.8 # Required by logger @@ -24,7 +24,8 @@ training: eval_freq: 1 # How often to run validation (in epochs) save_freq: 1 # How often to save checkpoints (in epochs) save_checkpoint: true - image_keys: ["observation.images.top", "observation.images.wrist"] + # image_keys: ["observation.images.top", "observation.images.wrist"] + image_keys: ["observation.images.laptop", "observation.images.phone"] label_key: "next.reward" eval: @@ -32,7 +33,7 @@ eval: num_samples_to_log: 30 # Number of validation samples to log in the table policy: - name: "hilserl/classifier/pick_place_lego_cube_1" + name: "hilserl/classifier/push_green_cube_hf_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_1" model_name: "facebook/convnext-base-224" model_type: "cnn" num_cameras: 2 # Has to be len(training.image_keys) diff --git a/lerobot/configs/robot/so100.yaml b/lerobot/configs/robot/so100.yaml index 0978de64..d57ae721 100644 --- a/lerobot/configs/robot/so100.yaml +++ b/lerobot/configs/robot/so100.yaml @@ -14,6 +14,9 @@ calibration_dir: .cache/calibration/so100 # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as # the number of motors in your follower arms. 
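The goal-position clamping added to manipulator.py in the hunk above is a torch.clamp against the per-joint bounds declared in the robot config (joint_position_relative_bounds, shown in the so100.yaml hunk that follows). The same idea in isolation, with invented bounds and an out-of-range request:

    import torch

    # hypothetical per-joint bounds, same format as joint_position_relative_bounds
    bounds = {
        "min": torch.tensor([-90.0, 20.0, 0.0, -30.0, 75.0, 0.0]),
        "max": torch.tensor([85.0, 190.0, 145.0, 100.0, 145.0, 90.0]),
    }

    goal_pos = torch.tensor([120.0, 10.0, 50.0, 0.0, 90.0, 45.0])
    # clip the requested goal so the follower stays inside the allowed range
    goal_pos = torch.clamp(goal_pos, bounds["min"], bounds["max"])
    print(goal_pos)  # tensor([85., 20., 50., 0., 90., 45.])

In the patch the same clamp runs in both teleop_step and send_action, before the existing max_relative_target capping.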
max_relative_target: null +joint_position_relative_bounds: + min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274] + max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792] leader_arms: main: diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py new file mode 100644 index 00000000..6ec9d89f --- /dev/null +++ b/lerobot/scripts/server/find_joint_limits.py @@ -0,0 +1,64 @@ +import argparse +import time + +import cv2 +import numpy as np + +from lerobot.common.robot_devices.control_utils import is_headless +from lerobot.common.robot_devices.robots.factory import make_robot +from lerobot.common.utils.utils import init_hydra_config + + +def find_joint_bounds( + robot, + control_time_s=20, + display_cameras=False, +): + # TODO(rcadene): Add option to record logs + if not robot.is_connected: + robot.connect() + + control_time_s = float("inf") + + timestamp = 0 + start_episode_t = time.perf_counter() + pos_list = [] + while timestamp < control_time_s: + observation, action = robot.teleop_step(record_data=True) + + pos_list.append(robot.follower_arms["main"].read("Present_Position")) + + if display_cameras and not is_headless(): + image_keys = [key for key in observation if "image" in key] + for key in image_keys: + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + + timestamp = time.perf_counter() - start_episode_t + if timestamp > 60: + max = np.max(np.stack(pos_list), 0) + min = np.min(np.stack(pos_list), 0) + print(f"Max angle position per joint {max}") + print(f"Min angle position per joint {min}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--robot-path", + type=str, + default="lerobot/configs/robot/koch.yaml", + help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.", + ) + parser.add_argument( + "--robot-overrides", + type=str, + nargs="*", + help="Any key=value arguments to override config values (use dots for.nested=overrides)", + ) + parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds") + args = parser.parse_args() + robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) + + robot = make_robot(robot_cfg) + find_joint_bounds(robot, control_time_s=args.control_time_s) diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py new file mode 100644 index 00000000..40dc2784 --- /dev/null +++ b/lerobot/scripts/server/gym_manipulator.py @@ -0,0 +1,697 @@ +import argparse +import logging +import time +from threading import Lock +from typing import Annotated, Any, Callable, Dict, Optional, Tuple + +import cv2 +import gymnasium as gym +import numpy as np +import torch +import torchvision.transforms.functional as F # noqa: N812 + +from lerobot.common.envs.utils import preprocess_observation +from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position +from lerobot.common.robot_devices.robots.factory import make_robot +from lerobot.common.utils.utils import init_hydra_config, log_say + +logging.basicConfig(level=logging.INFO) + + +class HILSerlRobotEnv(gym.Env): + """ + Gym-like environment wrapper for robot policy evaluation. + + This wrapper provides a consistent interface for interacting with the robot, + following the OpenAI Gym environment conventions. 
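find_joint_limits.py above records Present_Position while you teleoperate, then prints the per-joint extrema so they can be copied into joint_position_relative_bounds. The aggregation itself reduces to a numpy min/max over the recorded trajectory; a standalone sketch with synthetic readings in place of the robot:

    import numpy as np

    # stand-ins for robot.follower_arms["main"].read("Present_Position") samples
    pos_list = [np.random.uniform(-90, 180, size=6) for _ in range(500)]

    positions = np.stack(pos_list)      # (n_samples, n_joints)
    max_pos = positions.max(axis=0)
    min_pos = positions.min(axis=0)
    print(f"Max angle position per joint {max_pos}")
    print(f"Min angle position per joint {min_pos}")

The printed vectors have the same shape and format as the min/max lists in the so100.yaml hunk above.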
+ """ + + def __init__( + self, + robot, + use_delta_action_space: bool = True, + delta: float | None = None, + display_cameras=False, + ): + """ + Initialize the robot environment. + + Args: + robot: The robot interface object + reward_classifier: Optional reward classifier + fps: Frames per second for control + control_time_s: Total control time for each episode + display_cameras: Whether to display camera feeds + output_normalization_params_action: Bound parameters for the action space + delta: The delta for the relative joint position action space + """ + super().__init__() + + self.robot = robot + self.display_cameras = display_cameras + + # connect robot + if not self.robot.is_connected: + self.robot.connect() + + self.initial_follower_position = robot.follower_arms["main"].read("Present_Position") + + # Episode tracking + self.current_step = 0 + self.episode_data = None + + self.delta = delta + self.use_delta_action_space = use_delta_action_space + self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") + + self.relative_bounds_size = ( + self.robot.config.joint_position_relative_bounds["max"] + - self.robot.config.joint_position_relative_bounds["min"] + ) + + self.delta_relative_bounds_size = self.relative_bounds_size * self.delta + + self.robot.config.max_relative_target = self.delta_relative_bounds_size.float() + + # Dynamically determine observation and action spaces + self._setup_spaces() + + def _setup_spaces(self): + """ + Dynamically determine observation and action spaces based on robot capabilities. + + This method should be customized based on the specific robot's observation + and action representations. + """ + # Example space setup - you'll need to adapt this to your specific robot + example_obs = self.robot.capture_observation() + + # Observation space (assuming image-based observations) + image_keys = [key for key in example_obs if "image" in key] + state_keys = [key for key in example_obs if "image" not in key] + observation_spaces = { + key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8) + for key in image_keys + } + observation_spaces["observation.state"] = gym.spaces.Dict( + { + key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32) + for key in state_keys + } + ) + + self.observation_space = gym.spaces.Dict(observation_spaces) + + # Action space (assuming joint positions) + action_dim = len(self.robot.follower_arms["main"].read("Present_Position")) + if self.use_delta_action_space: + action_space_robot = gym.spaces.Box( + low=-self.relative_bounds_size.cpu().numpy(), + high=self.relative_bounds_size.cpu().numpy(), + shape=(action_dim,), + dtype=np.float32, + ) + else: + action_space_robot = gym.spaces.Box( + low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(), + high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(), + shape=(action_dim,), + dtype=np.float32, + ) + + self.action_space = gym.spaces.Tuple( + ( + action_space_robot, + gym.spaces.Discrete(2), + ), + ) + + def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """ + Reset the environment to initial state. 
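_setup_spaces above exposes the action as a gym.spaces.Tuple: a Box over (relative) joint positions plus a Discrete(2) flag that signals human intervention. A quick illustration of building and sampling such a space, with hypothetical bounds in place of values read from the robot config:

    import gymnasium as gym
    import numpy as np

    n_joints = 6
    relative_bounds_size = np.full(n_joints, 30.0, dtype=np.float32)  # hypothetical

    action_space = gym.spaces.Tuple(
        (
            gym.spaces.Box(low=-relative_bounds_size, high=relative_bounds_size, dtype=np.float32),
            gym.spaces.Discrete(2),  # 1 -> expert/teleop takes over, 0 -> policy action
        )
    )

    joint_action, intervention_flag = action_space.sample()
    print(joint_action.shape, intervention_flag)  # (6,) and 0 or 1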
+
+        Returns:
+            observation (dict): Initial observation
+            info (dict): Additional information
+        """
+        super().reset(seed=seed, options=options)
+
+        # Capture initial observation
+        observation = self.robot.capture_observation()
+
+        # Reset tracking variables
+        self.current_step = 0
+        self.episode_data = None
+
+        return observation, {"initial_position": self.initial_follower_position}
+
+    def step(
+        self, action: Tuple[np.ndarray, bool]
+    ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
+        """
+        Take a step in the environment.
+
+        Args:
+            action tuple(np.ndarray, bool):
+                Policy action to be executed on the robot and boolean to determine
+                whether to choose policy action or expert action.
+
+        Returns:
+            observation (dict): Next observation
+            reward (float): Reward for this step
+            terminated (bool): Whether the episode has terminated
+            truncated (bool): Whether the episode was truncated
+            info (dict): Additional information
+        """
+        # The action received is a tuple containing the policy action and an intervention bool.
+        # The boolean indicates whether we will use the expert's actions (through teleoperation) or the policy actions.
+        policy_action, intervention_bool = action
+        teleop_action = None
+        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
+        if isinstance(policy_action, torch.Tensor):
+            policy_action = policy_action.cpu().numpy()
+            policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
+        if not intervention_bool:
+            if self.use_delta_action_space:
+                target_joint_positions = self.current_joint_positions + self.delta * policy_action
+            else:
+                target_joint_positions = policy_action
+            self.robot.send_action(torch.from_numpy(target_joint_positions))
+            observation = self.robot.capture_observation()
+        else:
+            observation, teleop_action = self.robot.teleop_step(record_data=True)
+            teleop_action = teleop_action["action"]  # teleop step returns torch tensors but in a dict
+
+            # teleop actions are returned in absolute joint space
+            # If we are using a relative joint position action space,
+            # there will be a mismatch between the spaces of the policy and teleop actions
+            # Solution is to transform the teleop actions into relative space.
+            # teleop relative action is:
+            if self.use_delta_action_space:
+                teleop_action = teleop_action - self.current_joint_positions
+                if torch.any(teleop_action < -self.delta_relative_bounds_size * self.delta) and torch.any(
+                    teleop_action > self.delta_relative_bounds_size
+                ):
+                    print(
+                        f"relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
+                        f"lower bounds condition {teleop_action < -self.delta_relative_bounds_size}\n"
+                        f"upper bounds condition {teleop_action > self.delta_relative_bounds_size}"
+                    )
+                    teleop_action = torch.clamp(
+                        teleop_action, -self.delta_relative_bounds_size, self.delta_relative_bounds_size
+                    )
+
+        self.current_step += 1
+
+        reward = 0.0
+        terminated = False
+        truncated = False
+
+        return (
+            observation,
+            reward,
+            terminated,
+            truncated,
+            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
+        )
+
+    def render(self):
+        """
+        Render the environment (in this case, display camera feeds).
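In step() the policy branch interprets the action as a displacement scaled by delta and added to the current joint positions, while the intervention branch converts the absolute teleop reading back into that relative space and clamps it. The core arithmetic, sketched with dummy joint values and invented per-joint limits:

    import torch

    delta = 0.1
    current_joints = torch.tensor([0.0, 90.0, 90.0, 0.0, 90.0, 10.0])
    delta_bounds = torch.tensor([17.0, 16.0, 14.5, 13.4, 6.8, 8.8])  # hypothetical

    # policy branch: scaled relative displacement
    policy_action = torch.tensor([0.5, -1.0, 0.0, 0.2, 0.0, 1.0])
    target_joints = current_joints + delta * policy_action

    # intervention branch: teleop returns absolute joints, convert to relative and clamp
    teleop_abs = torch.tensor([5.0, 80.0, 95.0, 0.0, 91.0, 12.0])
    teleop_rel = torch.clamp(teleop_abs - current_joints, -delta_bounds, delta_bounds)
    print(target_joints)
    print(teleop_rel)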
+ """ + import cv2 + + observation = self.robot.capture_observation() + image_keys = [key for key in observation if "image" in key] + + for key in image_keys: + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + + cv2.waitKey(1) + + def close(self): + """ + Close the environment and disconnect the robot. + """ + if self.robot.is_connected: + self.robot.disconnect() + + +class ActionRepeatWrapper(gym.Wrapper): + def __init__(self, env, nb_repeat: int = 1): + super().__init__(env) + self.nb_repeat = nb_repeat + + def step(self, action): + for _ in range(self.nb_repeat): + obs, reward, done, truncated, info = self.env.step(action) + if done or truncated: + break + return obs, reward, done, truncated, info + + +class RelativeJointPositionActionWrapper(gym.Wrapper): + def __init__( + self, + env: HILSerlRobotEnv, + # output_normalization_params_action: dict[str, list[float]], + delta: float = 0.1, + ): + super().__init__(env) + self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position") + self.delta = delta + if delta > 1: + raise ValueError("Delta should be less than 1") + + def step(self, action): + action_joint = action + self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position") + if isinstance(self.env.action_space, gym.spaces.Tuple): + action_joint = action[0] + joint_positions = self.joint_positions + (self.delta * action_joint) + # clip the joint positions to the joint limits with the action space + joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high) + + if isinstance(self.env.action_space, gym.spaces.Tuple): + return self.env.step((joint_positions, action[1])) + + obs, reward, terminated, truncated, info = self.env.step(joint_positions) + if info["is_intervention"]: + # teleop actions are returned in absolute joint space + # If we are using a relative joint position action space, + # there will be a mismatch between the spaces of the policy and teleop actions + # Solution is to transform the teleop actions into relative space. 
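ActionRepeatWrapper above replays the same command nb_repeat times and returns the last transition. A toy illustration of that behaviour on a scripted stand-in environment (FakeEnv is purely hypothetical and only counts steps):

    import gymnasium as gym

    class FakeEnv(gym.Env):
        # minimal counting environment, used only to show the repeat behaviour
        observation_space = gym.spaces.Discrete(1000)
        action_space = gym.spaces.Discrete(2)

        def reset(self, seed=None, options=None):
            self.t = 0
            return self.t, {}

        def step(self, action):
            self.t += 1
            return self.t, 0.0, False, False, {}

    class ActionRepeatWrapper(gym.Wrapper):
        def __init__(self, env, nb_repeat: int = 1):
            super().__init__(env)
            self.nb_repeat = nb_repeat

        def step(self, action):
            for _ in range(self.nb_repeat):
                obs, reward, done, truncated, info = self.env.step(action)
                if done or truncated:
                    break
            return obs, reward, done, truncated, info

    env = ActionRepeatWrapper(FakeEnv(), nb_repeat=3)
    env.reset()
    obs, *_ = env.step(0)
    print(obs)  # 3: the underlying environment stepped three times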
+ self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position") + teleop_action = info["action_intervention"] # teleop actions are in absolute joint space + relative_teleop_action = (teleop_action - self.joint_positions) / self.delta + info["action_intervention"] = relative_teleop_action + + return self.env.step(joint_positions) + + +class RewardWrapper(gym.Wrapper): + def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"): + self.env = env + self.reward_classifier = torch.compile(reward_classifier) + self.device = device + + def step(self, action): + observation, _, terminated, truncated, info = self.env.step(action) + images = [ + observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key + ] + start_time = time.perf_counter() + with torch.inference_mode(): + reward = ( + self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0 + ) + # print(f"fps for reward classifier {1/(time.perf_counter() - start_time)}") + reward = reward.item() + # print(f"Reward from reward classifier {reward}") + return observation, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + return self.env.reset(seed=seed, options=options) + + +class TimeLimitWrapper(gym.Wrapper): + def __init__(self, env, control_time_s, fps): + self.env = env + self.control_time_s = control_time_s + self.fps = fps + + self.last_timestamp = 0.0 + self.episode_time_in_s = 0.0 + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + time_since_last_step = time.perf_counter() - self.last_timestamp + self.episode_time_in_s += time_since_last_step + self.last_timestamp = time.perf_counter() + + # check if last timestep took more time than the expected fps + if 1.0 / time_since_last_step < self.fps: + logging.warning(f"Current timestep is lower than the expected fps {self.fps}") + + if self.episode_time_in_s > self.control_time_s: + # Terminated = True + terminated = True + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + self.episode_time_in_s = 0.0 + self.last_timestamp = time.perf_counter() + return self.env.reset(seed=seed, options=options) + + +class ImageCropResizeWrapper(gym.Wrapper): + def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None): + self.env = env + self.crop_params_dict = crop_params_dict + print(f"obs_keys , {self.env.observation_space}") + print(f"crop params dict {crop_params_dict.keys()}") + for key_crop in crop_params_dict: + if key_crop not in self.env.observation_space.keys(): + raise ValueError(f"Key {key_crop} not in observation space") + for key in crop_params_dict: + top, left, height, width = crop_params_dict[key] + new_shape = (top + height, left + width) + self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) + + self.resize_size = resize_size + if self.resize_size is None: + self.resize_size = (128, 128) + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + for k in self.crop_params_dict: + device = obs[k].device + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + obs[k] = obs[k].to(device) + # print(f"observation with key {k} with size {obs[k].size()}") + cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), 
cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + return obs, reward, terminated, truncated, info + + +class ConvertToLeRobotObservation(gym.ObservationWrapper): + def __init__(self, env, device): + super().__init__(env) + self.device = device + + def observation(self, observation): + observation = preprocess_observation(observation) + + observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation} + observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()} + return observation + + +class KeyboardInterfaceWrapper(gym.Wrapper): + def __init__(self, env): + super().__init__(env) + self.listener = None + self.events = { + "exit_early": False, + "pause_policy": False, + "reset_env": False, + "human_intervention_step": False, + } + self.event_lock = Lock() # Thread-safe access to events + self._init_keyboard_listener() + + def _init_keyboard_listener(self): + """Initialize keyboard listener if not in headless mode""" + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + return + try: + from pynput import keyboard + + def on_press(key): + with self.event_lock: + try: + if key == keyboard.Key.right or key == keyboard.Key.esc: + print("Right arrow key pressed. Exiting loop...") + self.events["exit_early"] = True + elif key == keyboard.Key.space: + if not self.events["pause_policy"]: + print( + "Space key pressed. Human intervention required.\n" + "Place the leader in similar pose to the follower and press space again." + ) + self.events["pause_policy"] = True + log_say("Human intervention stage. Get ready to take over.", play_sounds=True) + elif self.events["pause_policy"] and not self.events["human_intervention_step"]: + self.events["human_intervention_step"] = True + print("Space key pressed. Human intervention starting.") + log_say("Starting human intervention.", play_sounds=True) + else: + self.events["pause_policy"] = False + self.events["human_intervention_step"] = False + print("Space key pressed for a third time.") + log_say("Continuing with policy actions.", play_sounds=True) + except Exception as e: + print(f"Error handling key press: {e}") + + self.listener = keyboard.Listener(on_press=on_press) + self.listener.start() + except ImportError: + logging.warning("Could not import pynput. Keyboard interface will not be available.") + self.listener = None + + def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]: + is_intervention = False + terminated_by_keyboard = False + + # Extract policy_action if needed + if isinstance(self.env.action_space, gym.spaces.Tuple): + policy_action = action[0] + + # Check the event flags without holding the lock for too long. + with self.event_lock: + if self.events["exit_early"]: + terminated_by_keyboard = True + # If we need to wait for human intervention, we note that outside the lock. 
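ImageCropResizeWrapper (defined above) uses torchvision.transforms.functional to crop each camera frame with per-key (top, left, height, width) parameters and resize it to a fixed size. The transform in isolation, applied to a random tensor standing in for the laptop camera, with the crop values used later in this script:

    import torch
    import torchvision.transforms.functional as F  # noqa: N812

    crop_params = (58, 89, 357, 455)   # top, left, height, width
    resize_size = (128, 128)

    frame = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8).float()  # stand-in CHW image
    cropped = F.crop(frame, *crop_params)
    resized = F.resize(cropped, list(resize_size))
    print(resized.shape)  # torch.Size([3, 128, 128])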
+ pause_policy = self.events["pause_policy"] + + if pause_policy: + # Now, wait for human_intervention_step without holding the lock + while True: + with self.event_lock: + if self.events["human_intervention_step"]: + is_intervention = True + break + time.sleep(0.1) # Check more frequently if desired + + # Execute the step in the underlying environment + obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention)) + return obs, reward, terminated or terminated_by_keyboard, truncated, info + + def reset(self, **kwargs) -> Tuple[Any, Dict]: + """ + Reset the environment and clear any pending events + """ + with self.event_lock: + self.events = {k: False for k in self.events} + return self.env.reset(**kwargs) + + def close(self): + """ + Properly clean up the keyboard listener when the environment is closed + """ + if self.listener is not None: + self.listener.stop() + super().close() + + +class ResetWrapper(gym.Wrapper): + def __init__( + self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5 + ): + super().__init__(env) + self.reset_fn = reset_fn + self.reset_time_s = reset_time_s + + self.robot = self.unwrapped.robot + self.init_pos = self.unwrapped.initial_follower_position + + def reset(self, *, seed=None, options=None): + if self.reset_fn is not None: + self.reset_fn(self.env) + else: + log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True) + start_time = time.perf_counter() + while time.perf_counter() - start_time < self.reset_time_s: + self.robot.teleop_step() + + log_say("Manual reseting of the environment done.", play_sounds=True) + return super().reset(seed=seed, options=options) + + +def make_robot_env( + robot, + reward_classifier, + crop_params_dict=None, + fps=30, + control_time_s=20, + reset_follower_pos=True, + display_cameras=False, + device="cuda:0", + resize_size=None, + reset_time_s=10, + delta_action=0.1, + nb_repeats=1, + use_relative_joint_positions=False, +): + """ + Factory function to create the robot environment. + + Mimics gym.make() for consistent environment creation. 
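TimeLimitWrapper (further up) accumulates wall-clock time between steps with time.perf_counter(), warns when a step ran slower than the requested fps, and flips terminated once control_time_s is exceeded. The timing logic on its own, with a sleep standing in for one control step and a deliberately short horizon:

    import time

    fps = 30
    control_time_s = 0.2            # short horizon, just for the demonstration
    episode_time_in_s = 0.0
    last_timestamp = time.perf_counter()
    terminated = False

    while not terminated:
        time.sleep(1 / fps)         # pretend this is env.step(...)
        now = time.perf_counter()
        time_since_last_step = now - last_timestamp
        episode_time_in_s += time_since_last_step
        last_timestamp = now
        if 1.0 / time_since_last_step < fps:
            print(f"current step is slower than the expected fps {fps}")
        terminated = episode_time_in_s > control_time_s

    print(f"episode ended after {episode_time_in_s:.3f}s")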
+ """ + env = HILSerlRobotEnv( + robot, + display_cameras=display_cameras, + delta=delta_action, + use_delta_action_space=use_relative_joint_positions, + ) + env = ConvertToLeRobotObservation(env, device) + if crop_params_dict is not None: + env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size) + env = RewardWrapper(env, reward_classifier, device=device) + env = TimeLimitWrapper(env, control_time_s, fps) + # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats) + env = KeyboardInterfaceWrapper(env) + env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s) + return env + + +def get_classifier(pretrained_path, config_path, device="mps"): + if pretrained_path is None or config_path is None: + return + + from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg + from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + + cfg = init_hydra_config(config_path) + + classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) + classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths + model = Classifier(classifier_config) + model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) + model = model.to(device) + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--fps", type=int, default=30, help="control frequency") + parser.add_argument( + "--robot-path", + type=str, + default="lerobot/configs/robot/koch.yaml", + help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.", + ) + parser.add_argument( + "--robot-overrides", + type=str, + nargs="*", + help="Any key=value arguments to override config values (use dots for.nested=overrides)", + ) + parser.add_argument( + "-p", + "--pretrained-policy-name-or-path", + help=( + "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights " + "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch " + "(useful for debugging). This argument is mutually exclusive with `--config`." + ), + ) + parser.add_argument( + "--config", + help=( + "Path to a yaml config you want to use for initializing a policy from scratch (useful for " + "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." 
+ ), + ) + parser.add_argument( + "--display-cameras", help=("Whether to display the camera feed while the rollout is happening") + ) + parser.add_argument( + "--reward-classifier-pretrained-path", + type=str, + default=None, + help="Path to the pretrained classifier weights.", + ) + parser.add_argument( + "--reward-classifier-config-file", + type=str, + default=None, + help="Path to a yaml config file that is necessary to build the reward classifier model.", + ) + parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds") + parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes") + args = parser.parse_args() + + robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) + robot = make_robot(robot_cfg) + + reward_classifier = get_classifier( + args.reward_classifier_pretrained_path, args.reward_classifier_config_file + ) + + crop_parameters = { + "observation.images.laptop": (58, 89, 357, 455), + "observation.images.phone": (3, 4, 471, 633), + } + + user_relative_joint_positions = True + + env = make_robot_env( + robot, + reward_classifier, + crop_parameters, + args.fps, + args.control_time_s, + args.reset_follower_pos, + args.display_cameras, + device="mps", + resize_size=None, + reset_time_s=10, + delta_action=0.1, + nb_repeats=1, + use_relative_joint_positions=user_relative_joint_positions, + ) + + env.reset() + init_pos = env.unwrapped.initial_follower_position + + right_goal = init_pos.copy() + right_goal[0] += 50 + + left_goal = init_pos.copy() + left_goal[0] -= 50 + + pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000) + + delta_angle = np.concatenate((-np.ones(50), np.ones(50))) * 100 + + while True: + action = np.zeros(len(init_pos)) + for i in range(len(delta_angle)): + start_loop_s = time.perf_counter() + action[0] = delta_angle[i] + obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False)) + if terminated or truncated: + env.reset() + + dt_s = time.perf_counter() - start_loop_s + busy_wait(1 / args.fps - dt_s) + # action = np.zeros(len(init_pos)) if user_relative_joint_positions else init_pos + # for i in range(len(pitch_angle)): + # if user_relative_joint_positions: + # action[0] = delta_angle[i] + # else: + # action[0] = pitch_angle[i] + # obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False)) + # if terminated or truncated: + # logging.info("Max control time reached, reset environment.") + # env.reset() + + # for i in reversed(range(len(pitch_angle))): + # if user_relative_joint_positions: + # action[0] = delta_angle[i] + # else: + # action[0] = pitch_angle[i] + # obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False)) + + # if terminated or truncated: + # logging.info("Max control time reached, reset environment.") + # env.reset()