Added additional wrappers for the environment: action repeat, keyboard interface, and reset wrapper.
Tested the reset mechanism, the keyboard interface, and the convert wrapper on the robots.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
parent efb1982eec
commit e0527b4a6b
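As a rough usage sketch (not part of this commit), the wrappers it introduces compose around `HILSerlRobotEnv` the same way `make_robot_env` does in the diff below. The robot config path, timing values, and the `robot.connect()` call are assumptions; the wrapper classes are the ones defined in the edited module.

```python
import torch

from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config

# Assumed robot config path; any follower/leader pair supported by make_robot should work.
robot = make_robot(init_hydra_config("lerobot/configs/robot/koch.yaml"))
robot.connect()

# Wrapper stack mirroring make_robot_env in this commit (the reward and crop wrappers are
# left out here because they are commented out in the factory below).
env = HILSerlRobotEnv(robot, display_cameras=False)
env = ConvertToLeRobotObservation(env, device="cuda:0")
env = TimeLimitWrapper(env, control_time_s=20, fps=30)
env = KeyboardInterfaceWrapper(env)  # space toggles human intervention, right arrow / esc ends the episode
env = ResetWrapper(env, reset_fn=None, reset_time_s=10)  # manual teleop reset window

obs, info = env.reset()
# Actions are (joint_target_tensor, is_intervention); here we simply hold the initial pose.
goal = torch.from_numpy(env.unwrapped.initial_follower_position)
obs, reward, terminated, truncated, info = env.step((goal, False))
```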
@@ -273,6 +273,9 @@ def act_with_policy(cfg: DictConfig):
         # TODO (michel-aractingi): Label the reward
         # if config.label_reward_on_actor:
         # reward = reward_classifier(obs)
+        if info["is_intervention"]:
+            # TODO: Check the shape
+            action = info["action_intervention"]

         list_transition_to_send_to_learner.append(
             Transition(
@@ -281,7 +284,7 @@ def act_with_policy(cfg: DictConfig):
                 reward=reward,
                 next_state=next_obs,
                 done=done,
-                complementary_info=None,
+                complementary_info=info,  # TODO Handle information for the transition, is_demonstraction: bool
             )
         )

@@ -332,6 +332,9 @@ def add_actor_information_and_train(
             transition = move_transition_to_device(transition, device=device)
             replay_buffer.add(**transition)

+            if transition.get("complementary_info", {}).get("is_interaction"):
+                offline_replay_buffer.add(**transition)
+
         while not interaction_message_queue.empty():
             interaction_message = interaction_message_queue.get()
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
@@ -1,18 +1,18 @@
 import argparse
 import logging
 import time
-from typing import Annotated, Any, Dict, Optional, Tuple
+from threading import Lock
+from typing import Annotated, Any, Callable, Dict, Optional, Tuple

 import gymnasium as gym
 import numpy as np
 import torch
-import torch.nn as nn
 import torchvision.transforms.functional as F  # noqa: N812

 from lerobot.common.envs.utils import preprocess_observation
-from lerobot.common.robot_devices.control_utils import reset_follower_position
+from lerobot.common.robot_devices.control_utils import is_headless, reset_follower_position
 from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config, log_say

 logging.basicConfig(level=logging.INFO)

@@ -28,7 +28,6 @@ class HILSerlRobotEnv(gym.Env):
     def __init__(
         self,
         robot,
-        reset_follower_position=True,
         display_cameras=False,
     ):
         """
@@ -53,8 +52,7 @@ class HILSerlRobotEnv(gym.Env):
         # Dynamically determine observation and action spaces
         self._setup_spaces()

-        self._initial_follower_position = robot.follower_arms["main"].read("Present_Position")
-        self.reset_follower_position = reset_follower_position
+        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")

         # Episode tracking
         self.current_step = 0
@@ -105,9 +103,6 @@ class HILSerlRobotEnv(gym.Env):
         """
         super().reset(seed=seed, options=options)

-        if self.reset_follower_position:
-            reset_follower_position(self.robot, target_position=self._initial_follower_position)
-
         # Capture initial observation
         observation = self.robot.capture_observation()

@@ -115,7 +110,7 @@ class HILSerlRobotEnv(gym.Env):
         self.current_step = 0
         self.episode_data = None

-        return observation, {}
+        return observation, {"initial_position": self.initial_follower_position}

     def step(
         self, action: Tuple[np.ndarray, bool]
|
@ -140,7 +135,7 @@ class HILSerlRobotEnv(gym.Env):
|
||||||
policy_action, intervention_bool = action
|
policy_action, intervention_bool = action
|
||||||
teleop_action = None
|
teleop_action = None
|
||||||
if not intervention_bool:
|
if not intervention_bool:
|
||||||
self.robot.send_action(policy_action.cpu().numpy())
|
self.robot.send_action(policy_action.cpu())
|
||||||
observation = self.robot.capture_observation()
|
observation = self.robot.capture_observation()
|
||||||
else:
|
else:
|
||||||
observation, teleop_action = self.robot.teleop_step(record_data=True)
|
observation, teleop_action = self.robot.teleop_step(record_data=True)
|
||||||
|
@@ -152,7 +147,13 @@ class HILSerlRobotEnv(gym.Env):
         terminated = False
         truncated = False

-        return observation, reward, terminated, truncated, {"action": teleop_action}
+        return (
+            observation,
+            reward,
+            terminated,
+            truncated,
+            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
+        )

     def render(self):
         """
@@ -176,37 +177,51 @@ class HILSerlRobotEnv(gym.Env):
         self.robot.disconnect()


-class HILSerlTimeLimitWrapper(gym.Wrapper):
-    def __init__(self, env, control_time_s, fps):
-        self.env = env
-        self.control_time_s = control_time_s
-        self.fps = fps
-
-        self.last_timestamp = 0.0
-        self.episode_time_in_s = 0.0
+class ActionRepeatWrapper(gym.Wrapper):
+    def __init__(self, env, nb_repeat: int = 1):
+        super().__init__(env)
+        self.nb_repeat = nb_repeat

     def step(self, action):
-        ret = self.env.step(action)
-        time_since_last_step = time.perf_counter() - self.last_timestamp
-        self.episode_time_in_s += time_since_last_step
-        self.last_timestamp = time.perf_counter()
-
-        # check if last timestep took more time than the expected fps
-        if 1.0 / time_since_last_step > self.fps:
-            logging.warning(f"Current timestep exceeded expected fps {self.fps}")
-
-        if self.episode_time_in_s > self.control_time_s:
-            # Terminated = True
-            ret[2] = True
-        return ret
-
-    def reset(self, seed=None, options=None):
-        self.episode_time_in_s = 0.0
-        self.last_timestamp = time.perf_counter()
-        return self.env.reset(seed, options=None)
+        for _ in range(self.nb_repeat):
+            obs, reward, done, truncated, info = self.env.step(action)
+            if done or truncated:
+                break
+        return obs, reward, done, truncated, info


-class HILSerlRewardWrapper(gym.Wrapper):
+class RelativeJointPositionActionWrapper(gym.Wrapper):
+    def __init__(self, env: HILSerlRobotEnv, delta: float = 0.1):
+        super().__init__(env)
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        self.delta = delta
+
+    def step(self, action):
+        action_joint = action
+        self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            action_joint = action[0]
+        joint_positions = self.joint_positions + (self.delta * action_joint)
+        # clip the joint positions to the joint limits with the action space
+        joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
+
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            return self.env.step((joint_positions, action[1]))
+
+        obs, reward, terminated, truncated, info = self.env.step(joint_positions)
+        if info["is_intervention"]:
+            # teleop actions are returned in absolute joint space
+            # If we are using a relative joint position action space,
+            # there will be a mismatch between the spaces of the policy and teleop actions
+            # Solution is to transform the teleop actions into relative space.
+            teleop_action = info["action_intervention"]  # teleop actions are in absolute joint space
+            relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
+            info["action_intervention"] = relative_teleop_action
+
+        return self.env.step(joint_positions)
+
+
+class RewardWrapper(gym.Wrapper):
     def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
         self.env = env
         self.reward_classifier = reward_classifier
@@ -225,7 +240,37 @@ class HILSerlRewardWrapper(gym.Wrapper):
         return self.env.reset(seed=seed, options=options)


-class HILSerlImageCropResizeWrapper(gym.Wrapper):
+class TimeLimitWrapper(gym.Wrapper):
+    def __init__(self, env, control_time_s, fps):
+        self.env = env
+        self.control_time_s = control_time_s
+        self.fps = fps
+
+        self.last_timestamp = 0.0
+        self.episode_time_in_s = 0.0
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        time_since_last_step = time.perf_counter() - self.last_timestamp
+        self.episode_time_in_s += time_since_last_step
+        self.last_timestamp = time.perf_counter()
+
+        # check if last timestep took more time than the expected fps
+        if 1.0 / time_since_last_step < self.fps:
+            logging.warning(f"Current timestep exceeded expected fps {self.fps}")
+
+        if self.episode_time_in_s > self.control_time_s:
+            # Terminated = True
+            terminated = True
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        self.episode_time_in_s = 0.0
+        self.last_timestamp = time.perf_counter()
+        return self.env.reset(seed=seed, options=options)
+
+
+class ImageCropResizeWrapper(gym.Wrapper):
     def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
         self.env = env
         self.crop_params_dict = crop_params_dict
@@ -260,6 +305,131 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
         return observation


+class KeyboardInterfaceWrapper(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.listener = None
+        self.events = {
+            "exit_early": False,
+            "pause_policy": False,
+            "reset_env": False,
+            "human_intervention_step": False,
+        }
+        self.event_lock = Lock()  # Thread-safe access to events
+        self._init_keyboard_listener()
+
+    def _init_keyboard_listener(self):
+        """Initialize keyboard listener if not in headless mode"""
+
+        if is_headless():
+            logging.warning(
+                "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
+            )
+            return
+        try:
+            from pynput import keyboard
+
+            def on_press(key):
+                with self.event_lock:
+                    try:
+                        if key == keyboard.Key.right or key == keyboard.Key.esc:
+                            print("Right arrow key pressed. Exiting loop...")
+                            self.events["exit_early"] = True
+                        elif key == keyboard.Key.space:
+                            if not self.events["pause_policy"]:
+                                print(
+                                    "Space key pressed. Human intervention required.\n"
+                                    "Place the leader in similar pose to the follower and press space again."
+                                )
+                                self.events["pause_policy"] = True
+                                log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
+                            elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
+                                self.events["human_intervention_step"] = True
+                                print("Space key pressed. Human intervention starting.")
+                                log_say("Starting human intervention.", play_sounds=True)
+                            else:
+                                self.events["pause_policy"] = False
+                                self.events["human_intervention_step"] = False
+                                print("Space key pressed for a third time.")
+                                log_say("Continuing with policy actions.", play_sounds=True)
+                    except Exception as e:
+                        print(f"Error handling key press: {e}")
+
+            self.listener = keyboard.Listener(on_press=on_press)
+            self.listener.start()
+        except ImportError:
+            logging.warning("Could not import pynput. Keyboard interface will not be available.")
+            self.listener = None
+
+    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
+        is_intervention = False
+        terminated_by_keyboard = False
+
+        # Extract policy_action if needed
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            policy_action = action[0]
+
+        # Check the event flags without holding the lock for too long.
+        with self.event_lock:
+            if self.events["exit_early"]:
+                terminated_by_keyboard = True
+            # If we need to wait for human intervention, we note that outside the lock.
+            pause_policy = self.events["pause_policy"]
+
+        if pause_policy:
+            # Now, wait for human_intervention_step without holding the lock
+            while True:
+                with self.event_lock:
+                    if self.events["human_intervention_step"]:
+                        is_intervention = True
+                        break
+                time.sleep(0.1)  # Check more frequently if desired
+
+        # Execute the step in the underlying environment
+        obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
+        return obs, reward, terminated or terminated_by_keyboard, truncated, info
+
+    def reset(self, **kwargs) -> Tuple[Any, Dict]:
+        """
+        Reset the environment and clear any pending events
+        """
+        with self.event_lock:
+            self.events = {k: False for k in self.events}
+        return self.env.reset(**kwargs)
+
+    def close(self):
+        """
+        Properly clean up the keyboard listener when the environment is closed
+        """
+        if self.listener is not None:
+            self.listener.stop()
+        super().close()
+
+
+class ResetWrapper(gym.Wrapper):
+    def __init__(
+        self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
+    ):
+        super().__init__(env)
+        self.reset_fn = reset_fn
+        self.reset_time_s = reset_time_s
+
+        self.robot = self.unwrapped.robot
+        self.init_pos = self.unwrapped.initial_follower_position
+
+    def reset(self, *, seed=None, options=None):
+        if self.reset_fn is not None:
+            self.reset_fn(self.env)
+        else:
+            log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
+            start_time = time.perf_counter()
+            while time.perf_counter() - start_time < self.reset_time_s:
+                self.robot.teleop_step()
+
+            log_say("Manual reseting of the environment done.", play_sounds=True)
+        return super().reset(seed=seed, options=options)
+
+
 def make_robot_env(
     robot,
     reward_classifier,
@@ -270,18 +440,27 @@ def make_robot_env(
     display_cameras=False,
     device="cuda:0",
     resize_size=None,
+    reset_time_s=10,
+    delta_action=0.1,
+    nb_repeats=1,
+    use_relative_joint_positions=False,
 ):
     """
     Factory function to create the robot environment.

     Mimics gym.make() for consistent environment creation.
     """
-    env = HILSerlRobotEnv(robot, reset_follower_pos, display_cameras)
+    env = HILSerlRobotEnv(robot, display_cameras)
     env = ConvertToLeRobotObservation(env, device)
-    if crop_params_dict is not None:
-        env = HILSerlImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
-    env = HILSerlRewardWrapper(env, reward_classifier)
-    env = HILSerlTimeLimitWrapper(env, control_time_s, fps)
+    # if crop_params_dict is not None:
+    # env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
+    # env = RewardWrapper(env, reward_classifier)
+    env = TimeLimitWrapper(env, control_time_s, fps)
+    # if use_relative_joint_positions:
+    # env = RelativeJointPositionActionWrapper(env, delta=delta_action)
+    # env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
+    env = KeyboardInterfaceWrapper(env)
+    env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
    return env

@@ -369,12 +548,37 @@ if __name__ == "__main__":
         args.reset_follower_pos,
         args.display_cameras,
         device="mps",
+        resize_size=None,
+        reset_time_s=10,
+        delta_action=0.1,
+        nb_repeats=1,
+        use_relative_joint_positions=False,
     )

     env.reset()
+    init_pos = env.unwrapped.initial_follower_position
+    goal_pos = init_pos
+
+    right_goal = init_pos.copy()
+    right_goal[0] += 50
+
+    left_goal = init_pos.copy()
+    left_goal[0] -= 50
+
+    # Michel is a beast
+    pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
+
     while True:
-        intervention_action = (None, True)
-        obs, reward, terminated, truncated, info = env.step(intervention_action)
-        if terminated or truncated:
-            logging.info("Max control time reached, reset environment.")
-            env.reset()
+        for i in range(len(pitch_angle)):
+            goal_pos[0] = pitch_angle[i]
+            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
+            if terminated or truncated:
+                logging.info("Max control time reached, reset environment.")
+                env.reset()
+
+        for i in reversed(range(len(pitch_angle))):
+            goal_pos[0] = pitch_angle[i]
+            obs, reward, terminated, truncated, info = env.step((torch.from_numpy(goal_pos), False))
+            if terminated or truncated:
+                logging.info("Max control time reached, reset environment.")
+                env.reset()
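One possible scripted `reset_fn` for the new `ResetWrapper` (illustrative only; the commit itself only exercises the manual-teleop branch). It reuses `reset_follower_position`, which the removed `HILSerlRobotEnv.reset` path called with the same signature. Note that the wrapper calls `reset_fn(self.env)` despite the `Callable[[], None]` annotation, so the callback takes the wrapped env:

```python
def scripted_reset(env):
    # Drive the follower back to the pose captured at construction time, mirroring the
    # behaviour this commit removes from HILSerlRobotEnv.reset().
    reset_follower_position(env.unwrapped.robot, target_position=env.unwrapped.initial_follower_position)


env = ResetWrapper(env, reset_fn=scripted_reset, reset_time_s=10)
```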