Add additional wrappers for the environment: action repeat, keyboard interface, reset wrapper.
Tested the reset mechanism, the keyboard interface, and the convert wrapper on the robots.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
parent efb1982eec · commit e0527b4a6b
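For context, here is a minimal sketch (illustrative only, not part of this diff) of how the new wrapper stack composes around the base environment, mirroring make_robot_env below. The robot config path, device, and timing values are assumptions, and the reward-classifier and image-crop wrappers are omitted since they are commented out in make_robot_env for now:

    import torch
    from lerobot.common.robot_devices.robots.factory import make_robot
    from lerobot.common.utils.utils import init_hydra_config

    # Hypothetical config path; use any valid robot config.
    robot = make_robot(init_hydra_config("lerobot/configs/robot/koch.yaml"))
    robot.connect()

    env = HILSerlRobotEnv(robot, display_cameras=False)       # base env, actions are (action, is_intervention)
    env = ConvertToLeRobotObservation(env, device="cpu")      # convert raw robot observations to LeRobot format
    env = TimeLimitWrapper(env, control_time_s=20, fps=30)    # end the episode after 20 s of control
    env = KeyboardInterfaceWrapper(env)                       # space: pause/intervene, right arrow or esc: exit early
    env = ResetWrapper(env, reset_fn=None, reset_time_s=10)   # manual teleoperated reset between episodes

    obs, info = env.reset()
    # Hold the initial pose as a dummy policy action; the second tuple element disables intervention.
    policy_action = torch.from_numpy(info["initial_position"])
    obs, reward, terminated, truncated, info = env.step((policy_action, False))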
@@ -273,6 +273,9 @@ def act_with_policy(cfg: DictConfig):
         # TODO (michel-aractingi): Label the reward
         # if config.label_reward_on_actor:
         #     reward = reward_classifier(obs)
+        if info["is_intervention"]:
+            # TODO: Check the shape
+            action = info["action_intervention"]

         list_transition_to_send_to_learner.append(
             Transition(
@@ -281,7 +284,7 @@ def act_with_policy(cfg: DictConfig):
                 reward=reward,
                 next_state=next_obs,
                 done=done,
-                complementary_info=None,
+                complementary_info=info,  # TODO Handle information for the transition, is_demonstration: bool
             )
         )

@@ -332,6 +332,9 @@ def add_actor_information_and_train(
         transition = move_transition_to_device(transition, device=device)
         replay_buffer.add(**transition)

+        if transition.get("complementary_info", {}).get("is_interaction"):
+            offline_replay_buffer.add(**transition)
+
         while not interaction_message_queue.empty():
             interaction_message = interaction_message_queue.get()
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
@@ -1,18 +1,18 @@
 import argparse
 import logging
 import time
-from typing import Annotated, Any, Dict, Optional, Tuple
+from threading import Lock
+from typing import Annotated, Any, Callable, Dict, Optional, Tuple

 import gymnasium as gym
 import numpy as np
 import torch
 import torch.nn as nn
 import torchvision.transforms.functional as F  # noqa: N812

 from lerobot.common.envs.utils import preprocess_observation
-from lerobot.common.robot_devices.control_utils import reset_follower_position
+from lerobot.common.robot_devices.control_utils import is_headless, reset_follower_position
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config, log_say

 logging.basicConfig(level=logging.INFO)
@@ -28,7 +28,6 @@ class HILSerlRobotEnv(gym.Env):
     def __init__(
         self,
         robot,
-        reset_follower_position=True,
         display_cameras=False,
     ):
         """
@@ -53,8 +52,7 @@ class HILSerlRobotEnv(gym.Env):
         # Dynamically determine observation and action spaces
         self._setup_spaces()

-        self._initial_follower_position = robot.follower_arms["main"].read("Present_Position")
-        self.reset_follower_position = reset_follower_position
+        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")

         # Episode tracking
         self.current_step = 0
@@ -105,9 +103,6 @@ class HILSerlRobotEnv(gym.Env):
         """
         super().reset(seed=seed, options=options)

-        if self.reset_follower_position:
-            reset_follower_position(self.robot, target_position=self._initial_follower_position)
-
         # Capture initial observation
         observation = self.robot.capture_observation()

@@ -115,7 +110,7 @@ class HILSerlRobotEnv(gym.Env):
         self.current_step = 0
         self.episode_data = None

-        return observation, {}
+        return observation, {"initial_position": self.initial_follower_position}

     def step(
         self, action: Tuple[np.ndarray, bool]
@@ -140,7 +135,7 @@ class HILSerlRobotEnv(gym.Env):
         policy_action, intervention_bool = action
         teleop_action = None
         if not intervention_bool:
-            self.robot.send_action(policy_action.cpu().numpy())
+            self.robot.send_action(policy_action.cpu())
             observation = self.robot.capture_observation()
         else:
             observation, teleop_action = self.robot.teleop_step(record_data=True)
@@ -152,7 +147,13 @@ class HILSerlRobotEnv(gym.Env):
         terminated = False
         truncated = False

-        return observation, reward, terminated, truncated, {"action": teleop_action}
+        return (
+            observation,
+            reward,
+            terminated,
+            truncated,
+            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
+        )

     def render(self):
         """
@@ -176,37 +177,51 @@ class HILSerlRobotEnv(gym.Env):
         self.robot.disconnect()


-class HILSerlTimeLimitWrapper(gym.Wrapper):
-    def __init__(self, env, control_time_s, fps):
-        self.env = env
-        self.control_time_s = control_time_s
-        self.fps = fps
-
-        self.last_timestamp = 0.0
-        self.episode_time_in_s = 0.0
+class ActionRepeatWrapper(gym.Wrapper):
+    def __init__(self, env, nb_repeat: int = 1):
+        super().__init__(env)
+        self.nb_repeat = nb_repeat

     def step(self, action):
-        ret = self.env.step(action)
-        time_since_last_step = time.perf_counter() - self.last_timestamp
-        self.episode_time_in_s += time_since_last_step
-        self.last_timestamp = time.perf_counter()
-
-        # check if last timestep took more time than the expected fps
-        if 1.0 / time_since_last_step > self.fps:
-            logging.warning(f"Current timestep exceeded expected fps {self.fps}")
-
-        if self.episode_time_in_s > self.control_time_s:
-            # Terminated = True
-            ret[2] = True
-        return ret
-
-    def reset(self, seed=None, options=None):
-        self.episode_time_in_s = 0.0
-        self.last_timestamp = time.perf_counter()
-        return self.env.reset(seed, options=None)
+        for _ in range(self.nb_repeat):
+            obs, reward, done, truncated, info = self.env.step(action)
+            if done or truncated:
+                break
+        return obs, reward, done, truncated, info


-class HILSerlRewardWrapper(gym.Wrapper):
+class RelativeJointPositionActionWrapper(gym.Wrapper):
+    def __init__(self, env: HILSerlRobotEnv, delta: float = 0.1):
+        super().__init__(env)
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        self.delta = delta
+
+    def step(self, action):
+        action_joint = action
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            action_joint = action[0]
+        joint_positions = self.joint_positions + (self.delta * action_joint)
+        # clip the joint positions to the joint limits with the action space
+        joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
+
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            return self.env.step((joint_positions, action[1]))
+
+        obs, reward, terminated, truncated, info = self.env.step(joint_positions)
+        if info["is_intervention"]:
+            # teleop actions are returned in absolute joint space
+            # If we are using a relative joint position action space,
+            # there will be a mismatch between the spaces of the policy and teleop actions
+            # Solution is to transform the teleop actions into relative space.
+            teleop_action = info["action_intervention"]  # teleop actions are in absolute joint space
+            relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
+            info["action_intervention"] = relative_teleop_action
+
+        return obs, reward, terminated, truncated, info
+
+
+class RewardWrapper(gym.Wrapper):
     def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
         self.env = env
         self.reward_classifier = reward_classifier
@@ -225,7 +240,37 @@ class HILSerlRewardWrapper(gym.Wrapper):
         return self.env.reset(seed=seed, options=options)


-class HILSerlImageCropResizeWrapper(gym.Wrapper):
+class TimeLimitWrapper(gym.Wrapper):
+    def __init__(self, env, control_time_s, fps):
+        self.env = env
+        self.control_time_s = control_time_s
+        self.fps = fps
+
+        self.last_timestamp = 0.0
+        self.episode_time_in_s = 0.0
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        time_since_last_step = time.perf_counter() - self.last_timestamp
+        self.episode_time_in_s += time_since_last_step
+        self.last_timestamp = time.perf_counter()
+
+        # check if last timestep took more time than the expected fps
+        if 1.0 / time_since_last_step < self.fps:
+            logging.warning(f"Current timestep exceeded expected fps {self.fps}")
+
+        if self.episode_time_in_s > self.control_time_s:
+            # Terminated = True
+            terminated = True
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        self.episode_time_in_s = 0.0
+        self.last_timestamp = time.perf_counter()
+        return self.env.reset(seed=seed, options=options)
+
+
+class ImageCropResizeWrapper(gym.Wrapper):
     def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
         self.env = env
         self.crop_params_dict = crop_params_dict
@@ -260,6 +305,131 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
         return observation


+class KeyboardInterfaceWrapper(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.listener = None
+        self.events = {
+            "exit_early": False,
+            "pause_policy": False,
+            "reset_env": False,
+            "human_intervention_step": False,
+        }
+        self.event_lock = Lock()  # Thread-safe access to events
+        self._init_keyboard_listener()
+
+    def _init_keyboard_listener(self):
+        """Initialize keyboard listener if not in headless mode"""
+
+        if is_headless():
+            logging.warning(
+                "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
+            )
+            return
+        try:
+            from pynput import keyboard
+
+            def on_press(key):
+                with self.event_lock:
+                    try:
+                        if key == keyboard.Key.right or key == keyboard.Key.esc:
+                            print("Right arrow key pressed. Exiting loop...")
+                            self.events["exit_early"] = True
+                        elif key == keyboard.Key.space:
+                            if not self.events["pause_policy"]:
+                                print(
+                                    "Space key pressed. Human intervention required.\n"
+                                    "Place the leader in similar pose to the follower and press space again."
+                                )
+                                self.events["pause_policy"] = True
+                                log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
+                            elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
+                                self.events["human_intervention_step"] = True
+                                print("Space key pressed. Human intervention starting.")
+                                log_say("Starting human intervention.", play_sounds=True)
+                            else:
+                                self.events["pause_policy"] = False
+                                self.events["human_intervention_step"] = False
+                                print("Space key pressed for a third time.")
+                                log_say("Continuing with policy actions.", play_sounds=True)
+                    except Exception as e:
+                        print(f"Error handling key press: {e}")
+
+            self.listener = keyboard.Listener(on_press=on_press)
+            self.listener.start()
+        except ImportError:
+            logging.warning("Could not import pynput. Keyboard interface will not be available.")
+            self.listener = None
+
+    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
+        is_intervention = False
+        terminated_by_keyboard = False
+
+        # Extract policy_action if needed
+        policy_action = action
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            policy_action = action[0]
+
+        # Check the event flags without holding the lock for too long.
+        with self.event_lock:
+            if self.events["exit_early"]:
+                terminated_by_keyboard = True
+            # If we need to wait for human intervention, we note that outside the lock.
+            pause_policy = self.events["pause_policy"]
+
+        if pause_policy:
+            # Now, wait for human_intervention_step without holding the lock
+            while True:
+                with self.event_lock:
+                    if self.events["human_intervention_step"]:
+                        is_intervention = True
+                        break
+                time.sleep(0.1)  # Check more frequently if desired
+
+        # Execute the step in the underlying environment
+        obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
+        return obs, reward, terminated or terminated_by_keyboard, truncated, info
+
+    def reset(self, **kwargs) -> Tuple[Any, Dict]:
+        """
+        Reset the environment and clear any pending events
+        """
+        with self.event_lock:
+            self.events = {k: False for k in self.events}
+        return self.env.reset(**kwargs)
+
+    def close(self):
+        """
+        Properly clean up the keyboard listener when the environment is closed
+        """
+        if self.listener is not None:
+            self.listener.stop()
+        super().close()
+
+
+class ResetWrapper(gym.Wrapper):
+    def __init__(
+        self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
+    ):
+        super().__init__(env)
+        self.reset_fn = reset_fn
+        self.reset_time_s = reset_time_s
+
+        self.robot = self.unwrapped.robot
+        self.init_pos = self.unwrapped.initial_follower_position
+
+    def reset(self, *, seed=None, options=None):
+        if self.reset_fn is not None:
+            self.reset_fn(self.env)
+        else:
+            log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
+            start_time = time.perf_counter()
+            while time.perf_counter() - start_time < self.reset_time_s:
+                self.robot.teleop_step()
+
+            log_say("Manual resetting of the environment done.", play_sounds=True)
+        return super().reset(seed=seed, options=options)
+
+
 def make_robot_env(
     robot,
     reward_classifier,
@@ -270,18 +440,27 @@ def make_robot_env(
     display_cameras=False,
     device="cuda:0",
     resize_size=None,
+    reset_time_s=10,
+    delta_action=0.1,
+    nb_repeats=1,
+    use_relative_joint_positions=False,
 ):
     """
     Factory function to create the robot environment.

     Mimics gym.make() for consistent environment creation.
     """
-    env = HILSerlRobotEnv(robot, reset_follower_pos, display_cameras)
+    env = HILSerlRobotEnv(robot, display_cameras)
     env = ConvertToLeRobotObservation(env, device)
-    if crop_params_dict is not None:
-        env = HILSerlImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
-    env = HILSerlRewardWrapper(env, reward_classifier)
-    env = HILSerlTimeLimitWrapper(env, control_time_s, fps)
+    # if crop_params_dict is not None:
+    #     env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
+    # env = RewardWrapper(env, reward_classifier)
+    env = TimeLimitWrapper(env, control_time_s, fps)
+    # if use_relative_joint_positions:
+    #     env = RelativeJointPositionActionWrapper(env, delta=delta_action)
+    # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
+    env = KeyboardInterfaceWrapper(env)
+    env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
     return env

@@ -369,12 +548,37 @@ if __name__ == "__main__":
         args.reset_follower_pos,
         args.display_cameras,
         device="mps",
+        resize_size=None,
+        reset_time_s=10,
+        delta_action=0.1,
+        nb_repeats=1,
+        use_relative_joint_positions=False,
     )

     env.reset()
+    init_pos = env.unwrapped.initial_follower_position
+    goal_pos = init_pos
+
+    right_goal = init_pos.copy()
+    right_goal[0] += 50
+
+    left_goal = init_pos.copy()
+    left_goal[0] -= 50
+
+    # Michel is a beast
+    pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
+
     while True:
         intervention_action = (None, True)
         obs, reward, terminated, truncated, info = env.step(intervention_action)
         if terminated or truncated:
             logging.info("Max control time reached, reset environment.")
             env.reset()
+
+        for i in range(len(pitch_angle)):
+            goal_pos[0] = pitch_angle[i]
+            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
+            if terminated or truncated:
+                logging.info("Max control time reached, reset environment.")
+                env.reset()
+
+        for i in reversed(range(len(pitch_angle))):
+            goal_pos[0] = pitch_angle[i]
+            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
+            if terminated or truncated:
+                logging.info("Max control time reached, reset environment.")
+                env.reset()