Added additional environment wrappers: action repeat, keyboard interface, and reset wrapper.

Tested the reset mechanism, the keyboard interface, and the ConvertToLeRobotObservation wrapper on the robots.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-02-04 17:41:14 +00:00
parent efb1982eec
commit e0527b4a6b
3 changed files with 262 additions and 52 deletions

View File

@@ -273,6 +273,9 @@ def act_with_policy(cfg: DictConfig):
# TODO (michel-aractingi): Label the reward
# if config.label_reward_on_actor:
# reward = reward_classifier(obs)
if info["is_intervention"]:
# TODO: Check the shape
action = info["action_intervention"]
list_transition_to_send_to_learner.append(
Transition(
@@ -281,7 +284,7 @@ def act_with_policy(cfg: DictConfig):
reward=reward,
next_state=next_obs,
done=done,
complementary_info=None,
complementary_info=info,  # TODO: handle extra information for the transition, e.g. is_demonstration: bool
)
)
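For reference, a sketch of the payload that now rides along as complementary_info (key names come from the env wrapper's step() later in this commit; is_demonstration is still only a TODO and hypothetical here):

teleop_action = None  # absolute joint action when a human intervened, else None
info = {
    "action_intervention": teleop_action,
    "is_intervention": teleop_action is not None,
    # "is_demonstration": ...  # planned field (see TODO above), not yet emitted
}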

View File

@@ -332,6 +332,9 @@ def add_actor_information_and_train(
transition = move_transition_to_device(transition, device=device)
replay_buffer.add(**transition)
if (transition.get("complementary_info") or {}).get("is_intervention"):
offline_replay_buffer.add(**transition)
while not interaction_message_queue.empty():
interaction_message = interaction_message_queue.get()
# If cfg.resume, shift the interaction step by the last checkpointed step so that logging is not broken
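Note that complementary_info can be None on transitions produced before this change, so the lookup above chains through an `or {}` fallback; a minimal sketch of why a plain .get() default is not enough:

# dict.get("complementary_info", {}) still returns None when the key exists
# with value None, so fall back with `or {}` before the second .get().
transition = {"complementary_info": None}
info = transition.get("complementary_info") or {}  # -> {}
assert info.get("is_intervention") is None  # safe: no AttributeError on None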

View File

@@ -1,18 +1,18 @@
import argparse
import logging
import time
from typing import Annotated, Any, Dict, Optional, Tuple
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F # noqa: N812
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import reset_follower_position
from lerobot.common.robot_devices.control_utils import is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.utils.utils import init_hydra_config, log_say
logging.basicConfig(level=logging.INFO)
@@ -28,7 +28,6 @@ class HILSerlRobotEnv(gym.Env):
def __init__(
self,
robot,
reset_follower_position=True,
display_cameras=False,
):
"""
@@ -53,8 +52,7 @@
# Dynamically determine observation and action spaces
self._setup_spaces()
self._initial_follower_position = robot.follower_arms["main"].read("Present_Position")
self.reset_follower_position = reset_follower_position
self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
# Episode tracking
self.current_step = 0
@@ -105,9 +103,6 @@
"""
super().reset(seed=seed, options=options)
if self.reset_follower_position:
reset_follower_position(self.robot, target_position=self._initial_follower_position)
# Capture initial observation
observation = self.robot.capture_observation()
@@ -115,7 +110,7 @@
self.current_step = 0
self.episode_data = None
return observation, {}
return observation, {"initial_position": self.initial_follower_position}
def step(
self, action: Tuple[np.ndarray, bool]
@@ -140,7 +135,7 @@
policy_action, intervention_bool = action
teleop_action = None
if not intervention_bool:
self.robot.send_action(policy_action.cpu().numpy())
self.robot.send_action(policy_action.cpu())
observation = self.robot.capture_observation()
else:
observation, teleop_action = self.robot.teleop_step(record_data=True)
@@ -152,7 +147,13 @@
terminated = False
truncated = False
return observation, reward, terminated, truncated, {"action": teleop_action}
return (
observation,
reward,
terminated,
truncated,
{"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
)
def render(self):
"""
@@ -176,37 +177,51 @@
self.robot.disconnect()
class HILSerlTimeLimitWrapper(gym.Wrapper):
def __init__(self, env, control_time_s, fps):
self.env = env
self.control_time_s = control_time_s
self.fps = fps
self.last_timestamp = 0.0
self.episode_time_in_s = 0.0
class ActionRepeatWrapper(gym.Wrapper):
def __init__(self, env, nb_repeat: int = 1):
super().__init__(env)
self.nb_repeat = nb_repeat
def step(self, action):
ret = self.env.step(action)
time_since_last_step = time.perf_counter() - self.last_timestamp
self.episode_time_in_s += time_since_last_step
self.last_timestamp = time.perf_counter()
# check if last timestep took more time than the expected fps
if 1.0 / time_since_last_step > self.fps:
logging.warning(f"Current timestep exceeded expected fps {self.fps}")
if self.episode_time_in_s > self.control_time_s:
# Terminated = True
ret[2] = True
return ret
def reset(self, seed=None, options=None):
self.episode_time_in_s = 0.0
self.last_timestamp = time.perf_counter()
return self.env.reset(seed, options=None)
for _ in range(self.nb_repeat):
obs, reward, done, truncated, info = self.env.step(action)
if done or truncated:
break
return obs, reward, done, truncated, info
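A quick usage sketch of the repeat wrapper in isolation (toy Gymnasium env; the real stack is assembled in make_robot_env below). Note it returns only the reward of the last repeat rather than the sum:

import gymnasium as gym

# Each outer step() replays the action nb_repeat times on the base env,
# breaking early if the episode terminates mid-repeat.
env = ActionRepeatWrapper(gym.make("Pendulum-v1"), nb_repeat=3)
obs, info = env.reset()
obs, reward, done, truncated, info = env.step(env.action_space.sample())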
class HILSerlRewardWrapper(gym.Wrapper):
class RelativeJointPositionActionWrapper(gym.Wrapper):
def __init__(self, env: HILSerlRobotEnv, delta: float = 0.1):
super().__init__(env)
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
self.delta = delta
def step(self, action):
action_joint = action
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
if isinstance(self.env.action_space, gym.spaces.Tuple):
action_joint = action[0]
joint_positions = self.joint_positions + (self.delta * action_joint)
# clip the target joint positions to the limits defined by the action space
joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
if isinstance(self.env.action_space, gym.spaces.Tuple):
return self.env.step((joint_positions, action[1]))
obs, reward, terminated, truncated, info = self.env.step(joint_positions)
if info["is_intervention"]:
# teleop actions are returned in absolute joint space
# If we are using a relative joint position action space,
# there will be a mismatch between the spaces of the policy and teleop actions
# Solution is to transform the teleop actions into relative space.
teleop_action = info["action_intervention"] # teleop actions are in absolute joint space
relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
info["action_intervention"] = relative_teleop_action
# Return the transition from the single step above; stepping the env again
# here would execute the action on the robot a second time.
return obs, reward, terminated, truncated, info
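The two mappings above, written out on toy numbers (q is the follower position read at the top of step()); a sketch:

import numpy as np

# Forward map (policy -> robot): absolute target = q + delta * action.
# Inverse map (teleop -> policy): action = (q_teleop - q) / delta.
q = np.array([10.0, 20.0])     # current joint positions
delta = 0.1
action = np.array([1.0, -2.0])
target = q + delta * action    # absolute target sent to the robot: [10.1, 19.8]
recovered = (target - q) / delta
assert np.allclose(recovered, action)  # teleop remap recovers the relative action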
class RewardWrapper(gym.Wrapper):
def __init__(self, env, reward_classifier: Optional[nn.Module], device: torch.device = "cuda"):
super().__init__(env)
self.reward_classifier = reward_classifier
@@ -225,7 +240,37 @@ class HILSerlRewardWrapper(gym.Wrapper):
return self.env.reset(seed=seed, options=options)
class HILSerlImageCropResizeWrapper(gym.Wrapper):
class TimeLimitWrapper(gym.Wrapper):
def __init__(self, env, control_time_s, fps):
super().__init__(env)
self.control_time_s = control_time_s
self.fps = fps
self.last_timestamp = 0.0
self.episode_time_in_s = 0.0
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
time_since_last_step = time.perf_counter() - self.last_timestamp
self.episode_time_in_s += time_since_last_step
self.last_timestamp = time.perf_counter()
# warn if the last step ran slower than the target fps
if 1.0 / time_since_last_step < self.fps:
logging.warning(f"Environment step took longer than expected for the target fps {self.fps}")
if self.episode_time_in_s > self.control_time_s:
# terminate the episode once the control-time budget is exhausted
terminated = True
return obs, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
self.episode_time_in_s = 0.0
self.last_timestamp = time.perf_counter()
return self.env.reset(seed=seed, options=options)
class ImageCropResizeWrapper(gym.Wrapper):
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
super().__init__(env)
self.crop_params_dict = crop_params_dict
@@ -260,6 +305,131 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
return observation
class KeyboardInterfaceWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.listener = None
self.events = {
"exit_early": False,
"pause_policy": False,
"reset_env": False,
"human_intervention_step": False,
}
self.event_lock = Lock() # Thread-safe access to events
self._init_keyboard_listener()
def _init_keyboard_listener(self):
"""Initialize keyboard listener if not in headless mode"""
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
return
try:
from pynput import keyboard
def on_press(key):
with self.event_lock:
try:
if key == keyboard.Key.right or key == keyboard.Key.esc:
print("Right arrow or Esc pressed. Exiting loop...")
self.events["exit_early"] = True
elif key == keyboard.Key.space:
if not self.events["pause_policy"]:
print(
"Space key pressed. Human intervention required.\n"
"Place the leader in similar pose to the follower and press space again."
)
self.events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
self.events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say("Starting human intervention.", play_sounds=True)
else:
self.events["pause_policy"] = False
self.events["human_intervention_step"] = False
print("Space key pressed for a third time.")
log_say("Continuing with policy actions.", play_sounds=True)
except Exception as e:
print(f"Error handling key press: {e}")
self.listener = keyboard.Listener(on_press=on_press)
self.listener.start()
except ImportError:
logging.warning("Could not import pynput. Keyboard interface will not be available.")
self.listener = None
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
is_intervention = False
terminated_by_keyboard = False
# Extract the policy action; default to the raw action for non-Tuple spaces
policy_action = action
if isinstance(self.env.action_space, gym.spaces.Tuple):
policy_action = action[0]
# Check the event flags without holding the lock for too long.
with self.event_lock:
if self.events["exit_early"]:
terminated_by_keyboard = True
# If we need to wait for human intervention, we note that outside the lock.
pause_policy = self.events["pause_policy"]
if pause_policy:
# Now, wait for human_intervention_step without holding the lock
while True:
with self.event_lock:
if self.events["human_intervention_step"]:
is_intervention = True
break
time.sleep(0.1) # Check more frequently if desired
# Execute the step in the underlying environment
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
return obs, reward, terminated or terminated_by_keyboard, truncated, info
def reset(self, **kwargs) -> Tuple[Any, Dict]:
"""
Reset the environment and clear any pending events
"""
with self.event_lock:
self.events = {k: False for k in self.events}
return self.env.reset(**kwargs)
def close(self):
"""
Properly clean up the keyboard listener when the environment is closed
"""
if self.listener is not None:
self.listener.stop()
super().close()
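A driver-loop sketch for the keyboard wrapper (assumes the Tuple action space of HILSerlRobotEnv; policy here is a hypothetical callable). The wrapper injects is_intervention itself via the space-bar state machine, so the caller always submits False:

obs, info = env.reset()
while True:
    # policy(obs) is a stand-in for the actor's action selection
    obs, reward, terminated, truncated, info = env.step((policy(obs), False))
    if info["is_intervention"]:
        logging.info("Human override; the logged action came from the teleoperator.")
    if terminated or truncated:
        obs, info = env.reset()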
class ResetWrapper(gym.Wrapper):
def __init__(
self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[Any], None]] = None, reset_time_s: float = 5
):
super().__init__(env)
self.reset_fn = reset_fn
self.reset_time_s = reset_time_s
self.robot = self.unwrapped.robot
self.init_pos = self.unwrapped.initial_follower_position
def reset(self, *, seed=None, options=None):
if self.reset_fn is not None:
self.reset_fn(self.env)
else:
log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
start_time = time.perf_counter()
while time.perf_counter() - start_time < self.reset_time_s:
self.robot.teleop_step()
log_say("Manual reseting of the environment done.", play_sounds=True)
return super().reset(seed=seed, options=options)
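A hypothetical reset_fn that snaps the follower back to its start pose instead of waiting on manual teleop; reset_follower_position is already imported above, and ResetWrapper calls reset_fn with the wrapped env:

def snap_back_reset(env):
    # Reach through the wrapper stack for the robot and the stored start pose.
    reset_follower_position(
        env.unwrapped.robot,
        target_position=env.unwrapped.initial_follower_position,
    )

env = ResetWrapper(env, reset_fn=snap_back_reset, reset_time_s=5)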
def make_robot_env(
robot,
reward_classifier,
@@ -270,18 +440,27 @@ def make_robot_env(
display_cameras=False,
device="cuda:0",
resize_size=None,
reset_time_s=10,
delta_action=0.1,
nb_repeats=1,
use_relative_joint_positions=False,
):
"""
Factory function to create the robot environment.
Mimics gym.make() for consistent environment creation.
"""
env = HILSerlRobotEnv(robot, reset_follower_pos, display_cameras)
env = HILSerlRobotEnv(robot, display_cameras)
env = ConvertToLeRobotObservation(env, device)
if crop_params_dict is not None:
env = HILSerlImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
env = HILSerlRewardWrapper(env, reward_classifier)
env = HILSerlTimeLimitWrapper(env, control_time_s, fps)
# if crop_params_dict is not None:
# env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
# env = RewardWrapper(env, reward_classifier)
env = TimeLimitWrapper(env, control_time_s, fps)
# if use_relative_joint_positions:
# env = RelativeJointPositionActionWrapper(env, delta=delta_action)
# env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
env = KeyboardInterfaceWrapper(env)
env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
return env
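For context, a minimal end-to-end sketch of driving the factory (hypothetical config path; fps and control_time_s are assumed keyword names, inferred from the TimeLimitWrapper arguments above):

robot = make_robot(init_hydra_config("lerobot/configs/robot/koch.yaml"))
robot.connect()
env = make_robot_env(
    robot,
    reward_classifier=None,
    fps=30,              # assumed keyword, consumed by TimeLimitWrapper
    control_time_s=20,   # assumed keyword, episode time budget in seconds
    display_cameras=False,
    device="cuda:0",
)
obs, info = env.reset()
# Zero action as a placeholder; the shape depends on the robot's joint count.
obs, reward, terminated, truncated, info = env.step((torch.zeros(6), False))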
@@ -369,12 +548,37 @@ if __name__ == "__main__":
args.reset_follower_pos,
args.display_cameras,
device="mps",
resize_size=None,
reset_time_s=10,
delta_action=0.1,
nb_repeats=1,
use_relative_joint_positions=False,
)
env.reset()
init_pos = env.unwrapped.initial_follower_position
goal_pos = init_pos.copy()  # copy so the sweep below does not mutate the stored initial position
right_goal = init_pos.copy()
right_goal[0] += 50
left_goal = init_pos.copy()
left_goal[0] -= 50
# Michel is a beast
pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
while True:
intervention_action = (None, True)
obs, reward, terminated, truncated, info = env.step(intervention_action)
if terminated or truncated:
logging.info("Max control time reached, reset environment.")
env.reset()
for i in range(len(pitch_angle)):
goal_pos[0] = pitch_angle[i]
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
if terminated or truncated:
logging.info("Max control time reached, reset environment.")
env.reset()
for i in reversed(range(len(pitch_angle))):
goal_pos[0] = pitch_angle[i]
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
if terminated or truncated:
logging.info("Max control time reached, reset environment.")
env.reset()