From e0527b4a6bf41650cd25389f0cf82d4f29ccf68e Mon Sep 17 00:00:00 2001
From: Michel Aractingi <michel.aractingi@huggingface.co>
Date: Tue, 4 Feb 2025 17:41:14 +0000
Subject: [PATCH] Added additional wrappers for the environment: action repeat,
 keyboard interface, and reset wrapper. Tested the reset mechanism, keyboard
 interface, and convert wrapper on the robots.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
---
 lerobot/scripts/server/actor_server.py        |   5 +-
 lerobot/scripts/server/learner_server.py      |   3 +
 .../server/wrappers/gym_manipulator.py        | 306 +++++++++++++++---
 3 files changed, 262 insertions(+), 52 deletions(-)

diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index 952590e8..be5c0818 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -273,6 +273,9 @@ def act_with_policy(cfg: DictConfig):
         # TODO (michel-aractingi): Label the reward
         # if config.label_reward_on_actor:
         #     reward = reward_classifier(obs)
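+        # If a human intervened during this step, store the teleoperated action in the transition
+        # instead of the policy action.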
+        if info["is_intervention"]:
+            # TODO: Check the shape
+            action = info["action_intervention"]
 
         list_transition_to_send_to_learner.append(
             Transition(
@@ -281,7 +284,7 @@ def act_with_policy(cfg: DictConfig):
                 reward=reward,
                 next_state=next_obs,
                 done=done,
-                complementary_info=None,
+                complementary_info=info,  # TODO: Handle information for the transition, e.g. is_demonstration: bool
             )
         )
 
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 6dd33fed..5766c69c 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -332,6 +332,9 @@ def add_actor_information_and_train(
                 transition = move_transition_to_device(transition, device=device)
                 replay_buffer.add(**transition)
 
+                if transition.get("complementary_info", {}).get("is_intervention"):
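+                    # Transitions collected during human intervention are also stored in the
+                    # offline buffer so they can be reused as demonstration-like data.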
+                    offline_replay_buffer.add(**transition)
+
         while not interaction_message_queue.empty():
             interaction_message = interaction_message_queue.get()
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
diff --git a/lerobot/scripts/server/wrappers/gym_manipulator.py b/lerobot/scripts/server/wrappers/gym_manipulator.py
index 749d4358..f95b7731 100644
--- a/lerobot/scripts/server/wrappers/gym_manipulator.py
+++ b/lerobot/scripts/server/wrappers/gym_manipulator.py
@@ -1,18 +1,18 @@
 import argparse
 import logging
 import time
-from typing import Annotated, Any, Dict, Optional, Tuple
+from threading import Lock
+from typing import Annotated, Any, Callable, Dict, Optional, Tuple
 
 import gymnasium as gym
 import numpy as np
 import torch
-import torch.nn as nn
 import torchvision.transforms.functional as F  # noqa: N812
 
 from lerobot.common.envs.utils import preprocess_observation
-from lerobot.common.robot_devices.control_utils import reset_follower_position
+from lerobot.common.robot_devices.control_utils import is_headless, reset_follower_position
 from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config, log_say
 
 logging.basicConfig(level=logging.INFO)
 
@@ -28,7 +28,6 @@ class HILSerlRobotEnv(gym.Env):
     def __init__(
         self,
         robot,
-        reset_follower_position=True,
         display_cameras=False,
     ):
         """
@@ -53,8 +52,7 @@ class HILSerlRobotEnv(gym.Env):
         # Dynamically determine observation and action spaces
         self._setup_spaces()
 
-        self._initial_follower_position = robot.follower_arms["main"].read("Present_Position")
-        self.reset_follower_position = reset_follower_position
+        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
 
         # Episode tracking
         self.current_step = 0
@@ -105,9 +103,6 @@ class HILSerlRobotEnv(gym.Env):
         """
         super().reset(seed=seed, options=options)
 
-        if self.reset_follower_position:
-            reset_follower_position(self.robot, target_position=self._initial_follower_position)
-
         # Capture initial observation
         observation = self.robot.capture_observation()
 
@@ -115,7 +110,7 @@ class HILSerlRobotEnv(gym.Env):
         self.current_step = 0
         self.episode_data = None
 
-        return observation, {}
+        return observation, {"initial_position": self.initial_follower_position}
 
     def step(
         self, action: Tuple[np.ndarray, bool]
@@ -140,7 +135,7 @@ class HILSerlRobotEnv(gym.Env):
         policy_action, intervention_bool = action
         teleop_action = None
         if not intervention_bool:
-            self.robot.send_action(policy_action.cpu().numpy())
+            self.robot.send_action(policy_action.cpu())
             observation = self.robot.capture_observation()
         else:
             observation, teleop_action = self.robot.teleop_step(record_data=True)
@@ -152,7 +147,13 @@ class HILSerlRobotEnv(gym.Env):
         terminated = False
         truncated = False
 
-        return observation, reward, terminated, truncated, {"action": teleop_action}
+        return (
+            observation,
+            reward,
+            terminated,
+            truncated,
+            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
+        )
 
     def render(self):
         """
@@ -176,37 +177,51 @@ class HILSerlRobotEnv(gym.Env):
             self.robot.disconnect()
 
 
-class HILSerlTimeLimitWrapper(gym.Wrapper):
-    def __init__(self, env, control_time_s, fps):
-        self.env = env
-        self.control_time_s = control_time_s
-        self.fps = fps
-
-        self.last_timestamp = 0.0
-        self.episode_time_in_s = 0.0
+class ActionRepeatWrapper(gym.Wrapper):
+    def __init__(self, env, nb_repeat: int = 1):
+        super().__init__(env)
+        self.nb_repeat = nb_repeat
 
     def step(self, action):
-        ret = self.env.step(action)
-        time_since_last_step = time.perf_counter() - self.last_timestamp
-        self.episode_time_in_s += time_since_last_step
-        self.last_timestamp = time.perf_counter()
-
-        # check if last timestep took more time than the expected fps
-        if 1.0 / time_since_last_step > self.fps:
-            logging.warning(f"Current timestep exceeded expected fps {self.fps}")
-
-        if self.episode_time_in_s > self.control_time_s:
-            # Terminated = True
-            ret[2] = True
-        return ret
-
-    def reset(self, seed=None, options=None):
-        self.episode_time_in_s = 0.0
-        self.last_timestamp = time.perf_counter()
-        return self.env.reset(seed, options=None)
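+        # Repeat the same action nb_repeat times; intermediate rewards are discarded and only the
+        # last transition is returned.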
+        for _ in range(self.nb_repeat):
+            obs, reward, done, truncated, info = self.env.step(action)
+            if done or truncated:
+                break
+        return obs, reward, done, truncated, info
 
 
-class HILSerlRewardWrapper(gym.Wrapper):
+class RelativeJointPositionActionWrapper(gym.Wrapper):
+    def __init__(self, env: HILSerlRobotEnv, delta: float = 0.1):
+        super().__init__(env)
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        self.delta = delta
+
+    def step(self, action):
+        action_joint = action
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            action_joint = action[0]
+        joint_positions = self.joint_positions + (self.delta * action_joint)
+        # clip the joint positions to the limits of the joint action space
+        joint_space = (
+            self.env.action_space[0]
+            if isinstance(self.env.action_space, gym.spaces.Tuple)
+            else self.env.action_space
+        )
+        joint_positions = np.clip(joint_positions, joint_space.low, joint_space.high)
+
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            return self.env.step((joint_positions, action[1]))
+
+        obs, reward, terminated, truncated, info = self.env.step(joint_positions)
+        if info["is_intervention"]:
+            # teleop actions are returned in absolute joint space
+            # If we are using a relative joint position action space,
+            # there will be a mismatch between the spaces of the policy and teleop actions
+            # Solution is to transform the teleop actions into relative space.
+            teleop_action = info["action_intervention"]  # teleop actions are in absolute joint space
+            relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
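+            # e.g. with delta=0.1, a teleop position 2 units above the measured joint position
+            # becomes a relative action of 20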
+            info["action_intervention"] = relative_teleop_action
+
+        return obs, reward, terminated, truncated, info
+
+
+class RewardWrapper(gym.Wrapper):
     def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
         self.env = env
         self.reward_classifier = reward_classifier
@@ -225,7 +240,37 @@ class HILSerlRewardWrapper(gym.Wrapper):
         return self.env.reset(seed=seed, options=options)
 
 
-class HILSerlImageCropResizeWrapper(gym.Wrapper):
+class TimeLimitWrapper(gym.Wrapper):
+    def __init__(self, env, control_time_s, fps):
+        self.env = env
+        self.control_time_s = control_time_s
+        self.fps = fps
+
+        self.last_timestamp = 0.0
+        self.episode_time_in_s = 0.0
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        time_since_last_step = time.perf_counter() - self.last_timestamp
+        self.episode_time_in_s += time_since_last_step
+        self.last_timestamp = time.perf_counter()
+
+        # warn if the last step took longer than the period implied by the expected fps
+        if 1.0 / time_since_last_step < self.fps:
+            logging.warning(f"Last step took longer than the expected period for {self.fps} fps")
+
+        if self.episode_time_in_s > self.control_time_s:
+            # Terminated = True
+            terminated = True
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        self.episode_time_in_s = 0.0
+        self.last_timestamp = time.perf_counter()
+        return self.env.reset(seed=seed, options=options)
+
+
+class ImageCropResizeWrapper(gym.Wrapper):
     def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
         self.env = env
         self.crop_params_dict = crop_params_dict
@@ -260,6 +305,131 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
         return observation
 
 
+class KeyboardInterfaceWrapper(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.listener = None
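+        # Keyboard events: right arrow / ESC end the episode early; space cycles
+        # pause -> human intervention -> back to policy control.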
+        self.events = {
+            "exit_early": False,
+            "pause_policy": False,
+            "reset_env": False,
+            "human_intervention_step": False,
+        }
+        self.event_lock = Lock()  # Thread-safe access to events
+        self._init_keyboard_listener()
+
+    def _init_keyboard_listener(self):
+        """Initialize keyboard listener if not in headless mode"""
+
+        if is_headless():
+            logging.warning(
+                "Headless environment detected. On-screen camera display and keyboard inputs will not be available."
+            )
+            return
+        try:
+            from pynput import keyboard
+
+            def on_press(key):
+                with self.event_lock:
+                    try:
+                        if key == keyboard.Key.right or key == keyboard.Key.esc:
+                            print("Right arrow or ESC key pressed. Exiting loop...")
+                            self.events["exit_early"] = True
+                        elif key == keyboard.Key.space:
+                            if not self.events["pause_policy"]:
+                                print(
+                                    "Space key pressed. Human intervention required.\n"
+                                    "Place the leader in a similar pose to the follower and press space again."
+                                )
+                                self.events["pause_policy"] = True
+                                log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
+                            elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
+                                self.events["human_intervention_step"] = True
+                                print("Space key pressed. Human intervention starting.")
+                                log_say("Starting human intervention.", play_sounds=True)
+                            else:
+                                self.events["pause_policy"] = False
+                                self.events["human_intervention_step"] = False
+                                print("Space key pressed for a third time.")
+                                log_say("Continuing with policy actions.", play_sounds=True)
+                    except Exception as e:
+                        print(f"Error handling key press: {e}")
+
+            self.listener = keyboard.Listener(on_press=on_press)
+            self.listener.start()
+        except ImportError:
+            logging.warning("Could not import pynput. Keyboard interface will not be available.")
+            self.listener = None
+
+    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
+        is_intervention = False
+        terminated_by_keyboard = False
+
+        # Extract the policy action; the wrapped env expects an (action, intervention) tuple
+        policy_action = action
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            policy_action = action[0]
+
+        # Check the event flags without holding the lock for too long.
+        with self.event_lock:
+            if self.events["exit_early"]:
+                terminated_by_keyboard = True
+            # If we need to wait for human intervention, we note that outside the lock.
+            pause_policy = self.events["pause_policy"]
+
+        if pause_policy:
+            # Now, wait for human_intervention_step without holding the lock
+            while True:
+                with self.event_lock:
+                    if self.events["human_intervention_step"]:
+                        is_intervention = True
+                        break
+                time.sleep(0.1)  # Check more frequently if desired
+
+        # Execute the step in the underlying environment
+        obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
+        return obs, reward, terminated or terminated_by_keyboard, truncated, info
+
+    def reset(self, **kwargs) -> Tuple[Any, Dict]:
+        """
+        Reset the environment and clear any pending events
+        """
+        with self.event_lock:
+            self.events = {k: False for k in self.events}
+        return self.env.reset(**kwargs)
+
+    def close(self):
+        """
+        Properly clean up the keyboard listener when the environment is closed
+        """
+        if self.listener is not None:
+            self.listener.stop()
+        super().close()
+
+
+class ResetWrapper(gym.Wrapper):
+    def __init__(
+        self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
+    ):
+        super().__init__(env)
+        self.reset_fn = reset_fn
+        self.reset_time_s = reset_time_s
+
+        self.robot = self.unwrapped.robot
+        self.init_pos = self.unwrapped.initial_follower_position
+
+    def reset(self, *, seed=None, options=None):
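+        # Without a custom reset_fn, the operator teleoperates the robot back to a start pose
+        # for reset_time_s seconds.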
+        if self.reset_fn is not None:
+            self.reset_fn(self.env)
+        else:
+            log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
+            start_time = time.perf_counter()
+            while time.perf_counter() - start_time < self.reset_time_s:
+                self.robot.teleop_step()
+
+            log_say("Manual resetting of the environment done.", play_sounds=True)
+        return super().reset(seed=seed, options=options)
+
+
 def make_robot_env(
     robot,
     reward_classifier,
@@ -270,18 +440,27 @@ def make_robot_env(
     display_cameras=False,
     device="cuda:0",
     resize_size=None,
+    reset_time_s=10,
+    delta_action=0.1,
+    nb_repeats=1,
+    use_relative_joint_positions=False,
 ):
     """
     Factory function to create the robot environment.
 
     Mimics gym.make() for consistent environment creation.
     """
-    env = HILSerlRobotEnv(robot, reset_follower_pos, display_cameras)
+    env = HILSerlRobotEnv(robot, display_cameras)
     env = ConvertToLeRobotObservation(env, device)
-    if crop_params_dict is not None:
-        env = HILSerlImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
-    env = HILSerlRewardWrapper(env, reward_classifier)
-    env = HILSerlTimeLimitWrapper(env, control_time_s, fps)
+    # if crop_params_dict is not None:
+    #     env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
+    # env = RewardWrapper(env, reward_classifier)
+    env = TimeLimitWrapper(env, control_time_s, fps)
+    # if use_relative_joint_positions:
+    #     env = RelativeJointPositionActionWrapper(env, delta=delta_action)
+    # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
+    env = KeyboardInterfaceWrapper(env)
+    env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
     return env
 
 
@@ -369,12 +548,37 @@ if __name__ == "__main__":
         args.reset_follower_pos,
         args.display_cameras,
         device="mps",
+        resize_size=None,
+        reset_time_s=10,
+        delta_action=0.1,
+        nb_repeats=1,
+        use_relative_joint_positions=False,
     )
 
     env.reset()
+    init_pos = env.unwrapped.initial_follower_position
+    goal_pos = init_pos.copy()  # copy so the sweep below does not mutate the env's stored initial position
+
+    right_goal = init_pos.copy()
+    right_goal[0] += 50
+
+    left_goal = init_pos.copy()
+    left_goal[0] -= 50
+
+    # Sweep the first joint back and forth between the left and right goals
+    pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
+
     while True:
-        intervention_action = (None, True)
-        obs, reward, terminated, truncated, info = env.step(intervention_action)
-        if terminated or truncated:
-            logging.info("Max control time reached, reset environment.")
-            env.reset()
+        for i in range(len(pitch_angle)):
+            goal_pos[0] = pitch_angle[i]
+            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
+            if terminated or truncated:
+                logging.info("Max control time reached, reset environment.")
+                env.reset()
+
+        for i in reversed(range(len(pitch_angle))):
+            goal_pos[0] = pitch_angle[i]
+            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
+            if terminated or truncated:
+                logging.info("Max control time reached, reset environment.")
+                env.reset()