From e0527b4a6bf41650cd25389f0cf82d4f29ccf68e Mon Sep 17 00:00:00 2001
From: Michel Aractingi <michel.aractingi@huggingface.co>
Date: Tue, 4 Feb 2025 17:41:14 +0000
Subject: [PATCH] Added additional wrappers for the environment: Action repeat,
 keyboard interface, reset wrapper

Tested the reset mechanism, the keyboard interface and the convert wrapper on
the robots.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
---
 lerobot/scripts/server/actor_server.py       |   5 +-
 lerobot/scripts/server/learner_server.py     |   3 +
 .../server/wrappers/gym_manipulator.py       | 306 +++++++++++++++---
 3 files changed, 262 insertions(+), 52 deletions(-)

diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index 952590e8..be5c0818 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -273,6 +273,9 @@ def act_with_policy(cfg: DictConfig):
             # TODO (michel-aractingi): Label the reward
             # if config.label_reward_on_actor:
             #     reward = reward_classifier(obs)
+            if info["is_intervention"]:
+                # TODO: Check the shape
+                action = info["action_intervention"]
 
             list_transition_to_send_to_learner.append(
                 Transition(
@@ -281,7 +284,7 @@ def act_with_policy(cfg: DictConfig):
                     reward=reward,
                     next_state=next_obs,
                     done=done,
-                    complementary_info=None,
+                    complementary_info=info,  # TODO Handle information for the transition, is_demonstration: bool
                 )
             )
 
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 6dd33fed..5766c69c 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -332,6 +332,9 @@ def add_actor_information_and_train(
             transition = move_transition_to_device(transition, device=device)
             replay_buffer.add(**transition)
 
+            if transition.get("complementary_info", {}).get("is_intervention"):
+                offline_replay_buffer.add(**transition)
+
         while not interaction_message_queue.empty():
             interaction_message = interaction_message_queue.get()
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
diff --git a/lerobot/scripts/server/wrappers/gym_manipulator.py b/lerobot/scripts/server/wrappers/gym_manipulator.py
index 749d4358..f95b7731 100644
--- a/lerobot/scripts/server/wrappers/gym_manipulator.py
+++ b/lerobot/scripts/server/wrappers/gym_manipulator.py
@@ -1,18 +1,18 @@
 import argparse
 import logging
 import time
-from typing import Annotated, Any, Dict, Optional, Tuple
+from threading import Lock
+from typing import Annotated, Any, Callable, Dict, Optional, Tuple
 
 import gymnasium as gym
 import numpy as np
 import torch
-import torch.nn as nn
 import torchvision.transforms.functional as F  # noqa: N812
 
 from lerobot.common.envs.utils import preprocess_observation
-from lerobot.common.robot_devices.control_utils import reset_follower_position
+from lerobot.common.robot_devices.control_utils import is_headless, reset_follower_position
 from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config, log_say
 
 logging.basicConfig(level=logging.INFO)
 
@@ -28,7 +28,6 @@ class HILSerlRobotEnv(gym.Env):
     def __init__(
         self,
         robot,
-        reset_follower_position=True,
         display_cameras=False,
     ):
         """
@@ -53,8 +52,7 @@ class HILSerlRobotEnv(gym.Env):
         # Dynamically determine observation and action spaces
         self._setup_spaces()
 
-        self._initial_follower_position = robot.follower_arms["main"].read("Present_Position")
-        self.reset_follower_position = reset_follower_position
+        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
 
         # Episode tracking
         self.current_step = 0
@@ -105,9 +103,6 @@ class HILSerlRobotEnv(gym.Env):
         """
         super().reset(seed=seed, options=options)
 
-        if self.reset_follower_position:
-            reset_follower_position(self.robot, target_position=self._initial_follower_position)
-
         # Capture initial observation
         observation = self.robot.capture_observation()
 
@@ -115,7 +110,7 @@ class HILSerlRobotEnv(gym.Env):
         self.current_step = 0
         self.episode_data = None
 
-        return observation, {}
+        return observation, {"initial_position": self.initial_follower_position}
 
     def step(
         self, action: Tuple[np.ndarray, bool]
@@ -140,7 +135,7 @@ class HILSerlRobotEnv(gym.Env):
         policy_action, intervention_bool = action
         teleop_action = None
         if not intervention_bool:
-            self.robot.send_action(policy_action.cpu().numpy())
+            self.robot.send_action(policy_action.cpu())
             observation = self.robot.capture_observation()
         else:
             observation, teleop_action = self.robot.teleop_step(record_data=True)
@@ -152,7 +147,13 @@ class HILSerlRobotEnv(gym.Env):
         terminated = False
         truncated = False
 
-        return observation, reward, terminated, truncated, {"action": teleop_action}
+        return (
+            observation,
+            reward,
+            terminated,
+            truncated,
+            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
+        )
 
     def render(self):
         """
@@ -176,37 +177,51 @@ class HILSerlRobotEnv(gym.Env):
         self.robot.disconnect()
 
 
-class HILSerlTimeLimitWrapper(gym.Wrapper):
-    def __init__(self, env, control_time_s, fps):
-        self.env = env
-        self.control_time_s = control_time_s
-        self.fps = fps
-
-        self.last_timestamp = 0.0
-        self.episode_time_in_s = 0.0
+class ActionRepeatWrapper(gym.Wrapper):
+    def __init__(self, env, nb_repeat: int = 1):
+        super().__init__(env)
+        self.nb_repeat = nb_repeat
 
     def step(self, action):
-        ret = self.env.step(action)
-        time_since_last_step = time.perf_counter() - self.last_timestamp
-        self.episode_time_in_s += time_since_last_step
-        self.last_timestamp = time.perf_counter()
-
-        # check if last timestep took more time than the expected fps
-        if 1.0 / time_since_last_step > self.fps:
-            logging.warning(f"Current timestep exceeded expected fps {self.fps}")
-
-        if self.episode_time_in_s > self.control_time_s:
-            # Terminated = True
-            ret[2] = True
-        return ret
-
-    def reset(self, seed=None, options=None):
-        self.episode_time_in_s = 0.0
-        self.last_timestamp = time.perf_counter()
-        return self.env.reset(seed, options=None)
+        for _ in range(self.nb_repeat):
+            obs, reward, done, truncated, info = self.env.step(action)
+            if done or truncated:
+                break
+        return obs, reward, done, truncated, info
 
 
-class HILSerlRewardWrapper(gym.Wrapper):
+class RelativeJointPositionActionWrapper(gym.Wrapper):
+    def __init__(self, env: HILSerlRobotEnv, delta: float = 0.1):
+        super().__init__(env)
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        self.delta = delta
+
+    def step(self, action):
+        action_joint = action
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            action_joint = action[0]
+        joint_positions = self.joint_positions + (self.delta * action_joint)
+        # clip the joint positions to the joint limits with the action space
+        joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
+
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            return self.env.step((joint_positions, action[1]))
+
+        obs, reward, terminated, truncated, info = self.env.step(joint_positions)
+        if info["is_intervention"]:
+            # teleop actions are returned in absolute joint space
+            # If we are using a relative joint position action space,
+            # there will be a mismatch between the spaces of the policy and teleop actions
+            # Solution is to transform the teleop actions into relative space.
+            teleop_action = info["action_intervention"]  # teleop actions are in absolute joint space
+            relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
+            info["action_intervention"] = relative_teleop_action
+
+        return obs, reward, terminated, truncated, info
+
+
+class RewardWrapper(gym.Wrapper):
     def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
         self.env = env
         self.reward_classifier = reward_classifier
@@ -225,7 +240,37 @@
         return self.env.reset(seed=seed, options=options)
 
 
-class HILSerlImageCropResizeWrapper(gym.Wrapper):
+class TimeLimitWrapper(gym.Wrapper):
+    def __init__(self, env, control_time_s, fps):
+        self.env = env
+        self.control_time_s = control_time_s
+        self.fps = fps
+
+        self.last_timestamp = 0.0
+        self.episode_time_in_s = 0.0
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        time_since_last_step = time.perf_counter() - self.last_timestamp
+        self.episode_time_in_s += time_since_last_step
+        self.last_timestamp = time.perf_counter()
+
+        # check if the last step took longer than expected given the target fps
+        if 1.0 / time_since_last_step < self.fps:
+            logging.warning(f"Current timestep exceeded expected fps {self.fps}")
+
+        if self.episode_time_in_s > self.control_time_s:
+            # Terminate the episode once the control time limit is reached
+            terminated = True
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        self.episode_time_in_s = 0.0
+        self.last_timestamp = time.perf_counter()
+        return self.env.reset(seed=seed, options=options)
+
+
+class ImageCropResizeWrapper(gym.Wrapper):
     def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
         self.env = env
         self.crop_params_dict = crop_params_dict
@@ -260,6 +305,131 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
         return observation
 
 
+class KeyboardInterfaceWrapper(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.listener = None
+        self.events = {
+            "exit_early": False,
+            "pause_policy": False,
+            "reset_env": False,
+            "human_intervention_step": False,
+        }
+        self.event_lock = Lock()  # Thread-safe access to events
+        self._init_keyboard_listener()
+
+    def _init_keyboard_listener(self):
+        """Initialize keyboard listener if not in headless mode"""
+
+        if is_headless():
+            logging.warning(
+                "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
+            )
+            return
+        try:
+            from pynput import keyboard
+
+            def on_press(key):
+                with self.event_lock:
+                    try:
+                        if key == keyboard.Key.right or key == keyboard.Key.esc:
+                            print("Right arrow or ESC key pressed. Exiting loop...")
+                            self.events["exit_early"] = True
+                        elif key == keyboard.Key.space:
+                            if not self.events["pause_policy"]:
+                                print(
+                                    "Space key pressed. Human intervention required.\n"
+                                    "Place the leader in similar pose to the follower and press space again."
+                                )
+                                self.events["pause_policy"] = True
+                                log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
+                            elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
+                                self.events["human_intervention_step"] = True
+                                print("Space key pressed. Human intervention starting.")
+                                log_say("Starting human intervention.", play_sounds=True)
+                            else:
+                                self.events["pause_policy"] = False
+                                self.events["human_intervention_step"] = False
+                                print("Space key pressed for a third time.")
+                                log_say("Continuing with policy actions.", play_sounds=True)
+                    except Exception as e:
+                        print(f"Error handling key press: {e}")
+
+            self.listener = keyboard.Listener(on_press=on_press)
+            self.listener.start()
+        except ImportError:
+            logging.warning("Could not import pynput. Keyboard interface will not be available.")
+            self.listener = None
+
+    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
+        is_intervention = False
+        terminated_by_keyboard = False
+
+        # Extract policy_action if needed
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            policy_action = action[0]
+
+        # Check the event flags without holding the lock for too long.
+        with self.event_lock:
+            if self.events["exit_early"]:
+                terminated_by_keyboard = True
+            # If we need to wait for human intervention, we note that outside the lock.
+            pause_policy = self.events["pause_policy"]
+
+        if pause_policy:
+            # Now, wait for human_intervention_step without holding the lock
+            while True:
+                with self.event_lock:
+                    if self.events["human_intervention_step"]:
+                        is_intervention = True
+                        break
+                time.sleep(0.1)  # Check more frequently if desired
+
+        # Execute the step in the underlying environment
+        obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
+        return obs, reward, terminated or terminated_by_keyboard, truncated, info
+
+    def reset(self, **kwargs) -> Tuple[Any, Dict]:
+        """
+        Reset the environment and clear any pending events
+        """
+        with self.event_lock:
+            self.events = {k: False for k in self.events}
+        return self.env.reset(**kwargs)
+
+    def close(self):
+        """
+        Properly clean up the keyboard listener when the environment is closed
+        """
+        if self.listener is not None:
+            self.listener.stop()
+        super().close()
+
+
+class ResetWrapper(gym.Wrapper):
+    def __init__(
+        self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
+    ):
+        super().__init__(env)
+        self.reset_fn = reset_fn
+        self.reset_time_s = reset_time_s
+
+        self.robot = self.unwrapped.robot
+        self.init_pos = self.unwrapped.initial_follower_position
+
+    def reset(self, *, seed=None, options=None):
+        if self.reset_fn is not None:
+            self.reset_fn(self.env)
+        else:
+            log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
+            start_time = time.perf_counter()
+            while time.perf_counter() - start_time < self.reset_time_s:
+                self.robot.teleop_step()
+
+            log_say("Manual resetting of the environment done.", play_sounds=True)
+        return super().reset(seed=seed, options=options)
+
+
 def make_robot_env(
     robot,
     reward_classifier,
@@ -270,18 +440,27 @@
     display_cameras=False,
     device="cuda:0",
     resize_size=None,
+    reset_time_s=10,
+    delta_action=0.1,
+    nb_repeats=1,
+    use_relative_joint_positions=False,
 ):
     """
     Factory function to create the robot environment.
     Mimics gym.make() for consistent environment creation.
""" - env = HILSerlRobotEnv(robot, reset_follower_pos, display_cameras) + env = HILSerlRobotEnv(robot, display_cameras) env = ConvertToLeRobotObservation(env, device) - if crop_params_dict is not None: - env = HILSerlImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size) - env = HILSerlRewardWrapper(env, reward_classifier) - env = HILSerlTimeLimitWrapper(env, control_time_s, fps) + # if crop_params_dict is not None: + # env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size) + # env = RewardWrapper(env, reward_classifier) + env = TimeLimitWrapper(env, control_time_s, fps) + # if use_relative_joint_positions: + # env = RelativeJointPositionActionWrapper(env, delta=delta_action) + # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats) + env = KeyboardInterfaceWrapper(env) + env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s) return env @@ -369,12 +548,37 @@ if __name__ == "__main__": args.reset_follower_pos, args.display_cameras, device="mps", + resize_size=None, + reset_time_s=10, + delta_action=0.1, + nb_repeats=1, + use_relative_joint_positions=False, ) env.reset() + init_pos = env.unwrapped.initial_follower_position + goal_pos = init_pos + + right_goal = init_pos.copy() + right_goal[0] += 50 + + left_goal = init_pos.copy() + left_goal[0] -= 50 + + # Michel is a beast + pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000) + while True: - intervention_action = (None, True) - obs, reward, terminated, truncated, info = env.step(intervention_action) - if terminated or truncated: - logging.info("Max control time reached, reset environment.") - env.reset() + for i in range(len(pitch_angle)): + goal_pos[0] = pitch_angle[i] + obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False)) + if terminated or truncated: + logging.info("Max control time reached, reset environment.") + env.reset() + + for i in reversed(range(len(pitch_angle))): + goal_pos[0] = pitch_angle[i] + obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False)) + if terminated or truncated: + logging.info("Max control time reached, reset environment.") + env.reset()