import argparse
import logging
import time
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F  # noqa: N812

from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
from lerobot.scripts.server.maniskill_manipulator import make_maniskill

logging.basicConfig(level=logging.INFO)


class HILSerlRobotEnv(gym.Env):
    """
    Gym-compatible environment for evaluating robotic control policies with integrated human intervention.

    This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both
    relative (delta) and absolute joint position commands and automatically configures its observation and action
    spaces based on the robot's sensors and configuration.

    The environment can switch between executing actions from a policy or using teleoperated actions (human
    intervention) during each step. When teleoperation is used, the override action is captured and returned in
    the `info` dict along with a flag `is_intervention`.
    """

    def __init__(
        self,
        robot,
        use_delta_action_space: bool = True,
        delta: float | None = None,
        display_cameras: bool = False,
    ):
        """
        Initialize the HILSerlRobotEnv environment.

        The environment is set up with a robot interface, which is used to capture observations and send joint
        commands. The setup supports both relative (delta) adjustments and absolute joint positions for
        controlling the robot.

        Args:
            robot: The robot interface object used to connect and interact with the physical robot.
            use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control.
                Otherwise, absolute joint positions are used.
            delta (float or None): A scaling factor for the relative adjustments applied to joint positions.
                Should be a value between 0 and 1 when using a delta action space.
            display_cameras (bool): If True, the robot's camera feeds are displayed during execution.
        """
        super().__init__()

        self.robot = robot
        self.display_cameras = display_cameras

        # Connect to the robot if not already connected.
        if not self.robot.is_connected:
            self.robot.connect()

        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")

        # Episode tracking.
        self.current_step = 0
        self.episode_data = None

        self.delta = delta
        self.use_delta_action_space = use_delta_action_space
        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")

        # Retrieve the size of the joint position interval bound.
        self.relative_bounds_size = (
            self.robot.config.joint_position_relative_bounds["max"]
            - self.robot.config.joint_position_relative_bounds["min"]
        )

        # Guard against a missing `delta`: only scale the bounds when a delta factor was provided,
        # otherwise multiplying by None would raise a TypeError.
        if self.delta is not None:
            self.delta_relative_bounds_size = self.relative_bounds_size * self.delta
        else:
            self.delta_relative_bounds_size = self.relative_bounds_size

        self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()

        # Dynamically configure the observation and action spaces.
        self._setup_spaces()

    def _setup_spaces(self):
        """
        Dynamically configure the observation and action spaces based on the robot's capabilities.

        Observation Space:
            - For keys with "image": A Box space with pixel values ranging from 0 to 255.
            - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range.

        Action Space:
            - The action space is defined as a Tuple where:
              • The first element is a Box space representing joint position commands. It is defined as relative
                (delta) or absolute, based on the configuration.
              • The second element is a Discrete space (with 2 values) serving as a flag for intervention
                (teleoperation).
        """
        example_obs = self.robot.capture_observation()

        # Define observation spaces for images and other states.
        image_keys = [key for key in example_obs if "image" in key]
        state_keys = [key for key in example_obs if "image" not in key]
        observation_spaces = {
            key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
            for key in image_keys
        }
        observation_spaces["observation.state"] = gym.spaces.Dict(
            {
                key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
                for key in state_keys
            }
        )

        self.observation_space = gym.spaces.Dict(observation_spaces)

        # Define the action space for joint positions along with setting an intervention flag.
        action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
        if self.use_delta_action_space:
            action_space_robot = gym.spaces.Box(
                low=-self.relative_bounds_size.cpu().numpy(),
                high=self.relative_bounds_size.cpu().numpy(),
                shape=(action_dim,),
                dtype=np.float32,
            )
        else:
            action_space_robot = gym.spaces.Box(
                low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(),
                high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(),
                shape=(action_dim,),
                dtype=np.float32,
            )

        self.action_space = gym.spaces.Tuple(
            (
                action_space_robot,
                gym.spaces.Discrete(2),
            ),
        )

    def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
        """
        Reset the environment to its initial state.

        This method resets the step counter and clears any episodic data.

        Args:
            seed (Optional[int]): A seed for random number generation to ensure reproducibility.
            options (Optional[dict]): Additional options to influence the reset behavior.

        Returns:
            A tuple containing:
                - observation (dict): The initial sensor observation.
                - info (dict): A dictionary with supplementary information, including the key "initial_position".
        """
        super().reset(seed=seed, options=options)

        # Capture the initial observation.
        observation = self.robot.capture_observation()

        # Reset episode tracking variables.
        self.current_step = 0
        self.episode_data = None

        return observation, {"initial_position": self.initial_follower_position}

    def step(
        self, action: Tuple[np.ndarray, bool]
    ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
        """
        Execute a single step within the environment using the specified action.

        The provided action is a tuple comprised of:
            • A policy action (joint position commands) that may be either absolute values or a delta.
            • A boolean flag indicating whether teleoperation (human intervention) should be used for this step.

        Behavior:
            - When the intervention flag is False, the environment processes and sends the policy action to the
              robot.
            - When True, a teleoperation step is executed. If using a delta action space, the absolute teleop
              action is converted to a relative change based on the current joint positions.

        Args:
            action (tuple): A tuple with two elements:
                - policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
                - intervention_bool (bool): True if the human operator intervenes by providing a teleoperation
                  input.

        Returns:
            tuple: A tuple containing:
                - observation (dict): The new sensor observation after taking the step.
                - reward (float): The step reward (default is 0.0 within this wrapper).
                - terminated (bool): True if the episode has reached a terminal state.
                - truncated (bool): True if the episode was truncated (e.g., time constraints).
                - info (dict): Additional debugging information including:
                    ◦ "action_intervention": The teleop action if intervention was used.
                    ◦ "is_intervention": Flag indicating whether teleoperation was employed.
        """
        policy_action, intervention_bool = action
        teleop_action = None
        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
        if isinstance(policy_action, torch.Tensor):
            policy_action = policy_action.cpu().numpy()
            policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)

        if not intervention_bool:
            if self.use_delta_action_space:
                target_joint_positions = self.current_joint_positions + self.delta * policy_action
            else:
                target_joint_positions = policy_action
            self.robot.send_action(torch.from_numpy(target_joint_positions))
            observation = self.robot.capture_observation()
        else:
            observation, teleop_action = self.robot.teleop_step(record_data=True)
            teleop_action = teleop_action["action"]  # Extract the action tensor from the teleop data dict.

            # When applying the delta action space, convert teleop absolute values to relative differences.
            if self.use_delta_action_space:
                teleop_action = (teleop_action - self.current_joint_positions) / self.delta
                # Clamp (and log) when the relative teleop delta leaves the bounds in either direction.
                if torch.any(teleop_action < -self.relative_bounds_size) or torch.any(
                    teleop_action > self.relative_bounds_size
                ):
                    logging.debug(
                        f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n"
                        f"lower bounds condition {teleop_action < -self.relative_bounds_size}\n"
                        f"upper bounds condition {teleop_action > self.relative_bounds_size}"
                    )

                    teleop_action = torch.clamp(
                        teleop_action, -self.relative_bounds_size, self.relative_bounds_size
                    )
            # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
            if teleop_action.dim() == 1:
                teleop_action = teleop_action.unsqueeze(0)

        # self.render()

        self.current_step += 1

        reward = 0.0
        terminated = False
        truncated = False

        return (
            observation,
            reward,
            terminated,
            truncated,
            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
        )

    def render(self):
        """
        Render the current state of the environment by displaying the robot's camera feeds.
        """
        import cv2

        observation = self.robot.capture_observation()
        image_keys = [key for key in observation if "image" in key]

        for key in image_keys:
            cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

    def close(self):
        """
        Close the environment and clean up resources by disconnecting the robot.

        If the robot is currently connected, this method properly terminates the connection to ensure that all
        associated resources are released.
        """
        if self.robot.is_connected:
            self.robot.disconnect()
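

# A minimal usage sketch for HILSerlRobotEnv, kept as a comment so importing this
# module stays side-effect free. It assumes a `robot` object already built with
# `make_robot`; the delta value below is illustrative, not a recommended setting.
#
#   env = HILSerlRobotEnv(robot=robot, use_delta_action_space=True, delta=0.1)
#   obs, info = env.reset()
#   zero_action = np.zeros(env.action_space[0].shape, dtype=np.float32)
#   # The second tuple element is the intervention flag: False executes the policy
#   # action, True hands control to the teleoperator for this step.
#   obs, reward, terminated, truncated, info = env.step((zero_action, False))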


class ActionRepeatWrapper(gym.Wrapper):
    def __init__(self, env, nb_repeat: int = 1):
        super().__init__(env)
        self.nb_repeat = nb_repeat

    def step(self, action):
        for _ in range(self.nb_repeat):
            obs, reward, done, truncated, info = self.env.step(action)
            if done or truncated:
                break
        return obs, reward, done, truncated, info


class RewardWrapper(gym.Wrapper):
    def __init__(self, env, reward_classifier, device: torch.device = "cuda"):
        """
        Wrapper that adds reward prediction to the environment using a trained classifier.

        Args:
            env: The environment to wrap
            reward_classifier: The reward classifier model
            device: The device to run the model on
        """
        super().__init__(env)

        # NOTE: We got 15% speedup by compiling the model.
        # Guard against a missing classifier: `torch.compile(None)` would not return None.
        self.reward_classifier = torch.compile(reward_classifier) if reward_classifier is not None else None

        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

    def step(self, action):
        observation, _, terminated, truncated, info = self.env.step(action)
        images = [
            observation[key].to(self.device, non_blocking=self.device.type == "cuda")
            for key in observation
            if "image" in key
        ]
        start_time = time.perf_counter()
        with torch.inference_mode():
            reward = (
                self.reward_classifier.predict_reward(images, threshold=0.8)
                if self.reward_classifier is not None
                else 0.0
            )
        info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time)

        # logging.info(f"Reward: {reward}")

        if reward == 1.0:
            terminated = True
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        return self.env.reset(seed=seed, options=options)
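

# A sketch of the reward-classifier contract assumed above: `predict_reward`
# takes a list of image tensors and returns 1.0 when the success probability
# clears the threshold, else 0.0. The class below is an illustrative stand-in,
# not a pinned API.
#
#   class DummyRewardClassifier(torch.nn.Module):
#       def predict_reward(self, images, threshold=0.8):
#           prob = torch.rand(())  # stand-in for a real success head
#           return 1.0 if prob.item() > threshold else 0.0
#
#   env = RewardWrapper(env=env, reward_classifier=DummyRewardClassifier(), device="cpu")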


class JointMaskingActionSpace(gym.Wrapper):
    def __init__(self, env, mask):
        """
        Wrapper to mask out dimensions of the action space.

        Args:
            env: The environment to wrap
            mask: Binary mask array where 0 indicates dimensions to remove
        """
        super().__init__(env)

        # Keep only dimensions where mask is 1.
        self.active_dims = np.where(mask)[0]

        # Validate the mask against the action space and create a new action space
        # with the masked dimensions removed.
        if isinstance(env.action_space, gym.spaces.Box):
            if len(mask) != env.action_space.shape[0]:
                raise ValueError("Mask length must match action space dimensions")
            low = env.action_space.low[self.active_dims]
            high = env.action_space.high[self.active_dims]
            self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)

        if isinstance(env.action_space, gym.spaces.Tuple):
            if len(mask) != env.action_space[0].shape[0]:
                raise ValueError("Mask length must match action space 0 dimensions")

            low = env.action_space[0].low[self.active_dims]
            high = env.action_space[0].high[self.active_dims]
            action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
            self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))

    def action(self, action):
        """
        Convert masked action back to full action space.

        Args:
            action: Action in masked space. For Tuple spaces, the first element is masked.

        Returns:
            Action in original space with masked dims set to 0.
        """
        # Determine whether we are handling a Tuple space or a Box.
        if isinstance(self.env.action_space, gym.spaces.Tuple):
            # Extract the masked component from the tuple.
            masked_action = action[0] if isinstance(action, tuple) else action
            # Create a full action for the Box element.
            full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
            full_box_action[self.active_dims] = masked_action
            # Return a tuple with the reconstructed Box action and the unchanged remainder.
            return (full_box_action, action[1])
        else:
            # For Box action spaces.
            masked_action = action if not isinstance(action, tuple) else action[0]
            full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
            full_action[self.active_dims] = masked_action
            return full_action

    def step(self, action):
        action = self.action(action)
        obs, reward, terminated, truncated, info = self.env.step(action)
        if "action_intervention" in info and info["action_intervention"] is not None:
            if info["action_intervention"].dim() == 1:
                info["action_intervention"] = info["action_intervention"][self.active_dims]
            else:
                info["action_intervention"] = info["action_intervention"][:, self.active_dims]
        return obs, reward, terminated, truncated, info
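

# A worked example of the masking logic (values illustrative): with
# mask = [1, 1, 0, 1] over a 4-DoF arm, the wrapper exposes a 3-dim action
# space, and `action()` re-expands it, pinning the masked joint to 0.
#
#   masked = np.array([0.1, -0.2, 0.3], dtype=np.float32)
#   # active_dims == [0, 1, 3], so the reconstructed full action becomes:
#   #   [0.1, -0.2, 0.0, 0.3]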


class TimeLimitWrapper(gym.Wrapper):
    def __init__(self, env, control_time_s, fps):
        super().__init__(env)
        self.control_time_s = control_time_s
        self.fps = fps

        self.last_timestamp = 0.0
        self.episode_time_in_s = 0.0

        self.max_episode_steps = int(self.control_time_s * self.fps)

        self.current_step = 0

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        time_since_last_step = time.perf_counter() - self.last_timestamp
        self.episode_time_in_s += time_since_last_step
        self.last_timestamp = time.perf_counter()
        self.current_step += 1
        # Check whether the last timestep took longer than the expected frame period.
        if 1.0 / time_since_last_step < self.fps:
            logging.debug(f"Current timestep exceeded expected fps {self.fps}")

        # Terminate on wall-clock time rather than step count:
        # if self.current_step >= self.max_episode_steps:
        #     terminated = True
        if self.episode_time_in_s > self.control_time_s:
            terminated = True
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        self.episode_time_in_s = 0.0
        self.last_timestamp = time.perf_counter()
        self.current_step = 0
        return self.env.reset(seed=seed, options=options)


class ImageCropResizeWrapper(gym.Wrapper):
    def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
        super().__init__(env)
        self.env = env
        self.crop_params_dict = crop_params_dict
        logging.debug(f"obs_keys , {self.env.observation_space}")
        logging.debug(f"crop params dict {crop_params_dict.keys()}")
        for key_crop in crop_params_dict:
            if key_crop not in self.env.observation_space.keys():  # noqa: SIM118
                raise ValueError(f"Key {key_crop} not in observation space")

        self.resize_size = resize_size
        if self.resize_size is None:
            self.resize_size = (128, 128)

        # After cropping, images are resized, so the resulting observation shape is the resize size.
        for key in crop_params_dict:
            self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=self.resize_size)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        for k in self.crop_params_dict:
            device = obs[k].device

            # Check for NaNs before processing.
            if torch.isnan(obs[k]).any():
                logging.error(f"NaN values detected in observation {k} before crop and resize")

            # torchvision ops are not fully supported on MPS; fall back to CPU.
            if device == torch.device("mps:0"):
                obs[k] = obs[k].cpu()

            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
            obs[k] = F.resize(obs[k], self.resize_size)

            # Check for NaNs after processing.
            if torch.isnan(obs[k]).any():
                logging.error(f"NaN values detected in observation {k} after crop and resize")

            obs[k] = obs[k].to(device)

        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        for k in self.crop_params_dict:
            device = obs[k].device
            if device == torch.device("mps:0"):
                obs[k] = obs[k].cpu()
            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
            obs[k] = F.resize(obs[k], self.resize_size)
            obs[k] = obs[k].to(device)
        return obs, info
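

# Example crop_params_dict layout (keys and values are illustrative): each entry
# is (top, left, height, width) in pixels, keyed by an image observation key.
#
#   crop_params_dict = {
#       "observation.images.front": (0, 0, 240, 320),
#       "observation.images.wrist": (10, 20, 200, 200),
#   }
#   env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=(128, 128))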


class ConvertToLeRobotObservation(gym.ObservationWrapper):
    def __init__(self, env, device):
        super().__init__(env)

        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

    def observation(self, observation):
        observation = preprocess_observation(observation)

        # Move every tensor to the target device (non-blocking transfer only applies on CUDA).
        observation = {
            key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
            for key in observation
        }
        return observation


class KeyboardInterfaceWrapper(gym.Wrapper):
    """
    Wrapper exposing a keyboard interface for episode control:
        - Right arrow / Esc: end the episode early.
        - "s": mark the episode as a success (reward 1 and termination).
        - Space: pressed once pauses the policy, pressed twice starts human
          intervention, pressed a third time resumes policy actions.
    """

    def __init__(self, env):
        super().__init__(env)
        self.listener = None
        self.events = {
            "exit_early": False,
            "pause_policy": False,
            "reset_env": False,
            "human_intervention_step": False,
            "episode_success": False,
        }
        self.event_lock = Lock()  # Thread-safe access to events
        self._init_keyboard_listener()

    def _init_keyboard_listener(self):
        """Initialize the keyboard listener if not in headless mode."""
        if is_headless():
            logging.warning(
                "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
            )
            return
        try:
            from pynput import keyboard

            def on_press(key):
                with self.event_lock:
                    try:
                        if key == keyboard.Key.right or key == keyboard.Key.esc:
                            print("Right arrow or escape key pressed. Exiting loop...")
                            self.events["exit_early"] = True
                            return
                        if hasattr(key, "char") and key.char == "s":
                            print("Key 's' pressed. Episode success triggered.")
                            self.events["episode_success"] = True
                            return
                        if key == keyboard.Key.space and not self.events["exit_early"]:
                            if not self.events["pause_policy"]:
                                print(
                                    "Space key pressed. Human intervention required.\n"
                                    "Place the leader in a similar pose to the follower and press space again."
                                )
                                self.events["pause_policy"] = True
                                log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
                                return
                            if self.events["pause_policy"] and not self.events["human_intervention_step"]:
                                self.events["human_intervention_step"] = True
                                print("Space key pressed. Human intervention starting.")
                                log_say("Starting human intervention.", play_sounds=True)
                                return
                            if self.events["pause_policy"] and self.events["human_intervention_step"]:
                                self.events["pause_policy"] = False
                                self.events["human_intervention_step"] = False
                                print("Space key pressed for a third time.")
                                log_say("Continuing with policy actions.", play_sounds=True)
                                return
                    except Exception as e:
                        print(f"Error handling key press: {e}")

            self.listener = keyboard.Listener(on_press=on_press)
            self.listener.start()
        except ImportError:
            logging.warning("Could not import pynput. Keyboard interface will not be available.")
            self.listener = None

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
        is_intervention = False
        terminated_by_keyboard = False

        # Extract policy_action if needed; plain actions are passed through unchanged.
        policy_action = action[0] if isinstance(self.env.action_space, gym.spaces.Tuple) else action

        # Check the event flags without holding the lock for too long.
        with self.event_lock:
            if self.events["exit_early"]:
                terminated_by_keyboard = True
            pause_policy = self.events["pause_policy"]

        if pause_policy:
            # Now, wait for human_intervention_step without holding the lock.
            while True:
                with self.event_lock:
                    if self.events["human_intervention_step"]:
                        is_intervention = True
                        break
                time.sleep(0.1)  # Check more frequently if desired

        # Execute the step in the underlying environment.
        obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))

        # Override reward and termination if the episode success event was triggered.
        with self.event_lock:
            if self.events["episode_success"]:
                reward = 1
                terminated_by_keyboard = True

        return obs, reward, terminated or terminated_by_keyboard, truncated, info

    def reset(self, **kwargs) -> Tuple[Any, Dict]:
        """
        Reset the environment and clear any pending events.
        """
        with self.event_lock:
            self.events = {k: False for k in self.events}
        return self.env.reset(**kwargs)

    def close(self):
        """
        Properly clean up the keyboard listener when the environment is closed.
        """
        if self.listener is not None:
            self.listener.stop()
        super().close()


class ResetWrapper(gym.Wrapper):
    def __init__(
        self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
    ):
        super().__init__(env)
        self.reset_fn = reset_fn
        self.reset_time_s = reset_time_s

        self.robot = self.unwrapped.robot
        self.init_pos = self.unwrapped.initial_follower_position

    def reset(self, *, seed=None, options=None):
        if self.reset_fn is not None:
            self.reset_fn(self.env)
        else:
            log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
            start_time = time.perf_counter()
            while time.perf_counter() - start_time < self.reset_time_s:
                self.robot.teleop_step()

            log_say("Manual resetting of the environment done.", play_sounds=True)
        return super().reset(seed=seed, options=options)


class BatchCompitableWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)

    def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Add a batch dimension to unbatched image and state tensors, so the env
        # returns observations of shape (b, ...).
        for key in observation:
            if "image" in key and observation[key].dim() == 3:
                observation[key] = observation[key].unsqueeze(0)
            if "state" in key and observation[key].dim() == 1:
                observation[key] = observation[key].unsqueeze(0)
        return observation


# TODO: REMOVE TH


def make_robot_env(
    robot,
    reward_classifier,
    cfg,
    n_envs: int = 1,
) -> gym.Env:
    """
    Factory function to create a robot environment.

    Args:
        robot: Robot instance to control
        reward_classifier: Classifier model for computing rewards
        cfg: Configuration object containing environment parameters
        n_envs: Number of environments to create in parallel. Defaults to 1.

    Returns:
        A gym environment with all the necessary wrappers applied.
    """
    if "maniskill" in cfg.env.name:
        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
        env = make_maniskill(
            cfg=cfg,
            n_envs=n_envs,
        )
        return env

    # Create the base environment.
    env = HILSerlRobotEnv(
        robot=robot,
        display_cameras=cfg.env.wrapper.display_cameras,
        delta=cfg.env.wrapper.delta_action,
        use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions,
    )

    # Add observation and image processing.
    env = ConvertToLeRobotObservation(env=env, device=cfg.device)
    if cfg.env.wrapper.crop_params_dict is not None:
        env = ImageCropResizeWrapper(
            env=env, crop_params_dict=cfg.env.wrapper.crop_params_dict, resize_size=cfg.env.wrapper.resize_size
        )

    # Add reward computation and control wrappers.
    env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
    env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
    env = KeyboardInterfaceWrapper(env=env)
    env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s)
    env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
    env = BatchCompitableWrapper(env=env)

    return env
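

# The resulting wrapper stack on the real-robot path, innermost first:
#   HILSerlRobotEnv -> ConvertToLeRobotObservation -> ImageCropResizeWrapper
#   -> RewardWrapper -> TimeLimitWrapper -> KeyboardInterfaceWrapper
#   -> ResetWrapper -> JointMaskingActionSpace -> BatchCompitableWrapper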


def get_classifier(pretrained_path, config_path, device="mps"):
    if pretrained_path is None or config_path is None:
        return None

    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    cfg = init_hydra_config(config_path)

    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
    model = Classifier(classifier_config)
    model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
    model = model.to(device)
    return model


def replay_episode(env, repo_id, root=None, episode=0):
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    local_files_only = root is not None
    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
    actions = dataset.hf_dataset.select_columns("action")

    for idx in range(dataset.num_frames):
        start_episode_t = time.perf_counter()

        action = actions[idx]["action"][:4]
        print(action)
        # Undo the delta scaling applied inside the env step.
        env.step((action / env.unwrapped.delta, False))

        dt_s = time.perf_counter() - start_episode_t
        # Replay at 10 fps.
        busy_wait(1 / 10 - dt_s)
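

# Example invocation of this script (paths and overrides are illustrative):
#
#   python <path/to/this_script>.py --fps 30 \
#       --robot-path lerobot/configs/robot/koch.yaml \
#       --env-path <env.yaml> --env-overrides <key=value>
#
# Without --replay-repo-id, it drives the robot with smoothed random actions.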


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fps", type=int, default=30, help="control frequency")
    parser.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )
    parser.add_argument(
        "--robot-overrides",
        type=str,
        nargs="*",
        help="Any key=value arguments to override config values (use dots for nested=overrides)",
    )
    parser.add_argument(
        "-p",
        "--pretrained-policy-name-or-path",
        help=(
            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
            "(useful for debugging). This argument is mutually exclusive with `--config`."
        ),
    )
    parser.add_argument(
        "--config",
        help=(
            "Path to a yaml config you want to use for initializing a policy from scratch (useful for "
            "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
        ),
    )
    parser.add_argument(
        "--display-cameras", help="Whether to display the camera feed while the rollout is happening"
    )
    parser.add_argument(
        "--reward-classifier-pretrained-path",
        type=str,
        default=None,
        help="Path to the pretrained classifier weights.",
    )
    parser.add_argument(
        "--reward-classifier-config-file",
        type=str,
        default=None,
        help="Path to a yaml config file that is necessary to build the reward classifier model.",
    )
    parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
    parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file")
    parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
    parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
    parser.add_argument("--replay-repo-id", type=str, default=None, help="Repo ID of the episode to replay")
    parser.add_argument("--replay-root", type=str, default=None, help="Root of the dataset to replay")
    parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
    args = parser.parse_args()

    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
    robot = make_robot(robot_cfg)

    reward_classifier = get_classifier(
        args.reward_classifier_pretrained_path, args.reward_classifier_config_file
    )
    user_relative_joint_positions = True

    cfg = init_hydra_config(args.env_path, args.env_overrides)
    # `make_robot_env` expects the full config: it reads `cfg.env`, `cfg.fps`, and `cfg.device`.
    env = make_robot_env(
        robot,
        reward_classifier,
        cfg,
    )

    env.reset()

    if args.replay_repo_id is not None:
        replay_episode(env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode)
        exit()

    # Retrieve the robot's action space for joint commands.
    action_space_robot = env.action_space.spaces[0]

    # Initialize the smoothed action as a random sample.
    smoothed_action = action_space_robot.sample()

    # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
    # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
    alpha = 0.4

    while True:
        start_loop_s = time.perf_counter()
        # Sample a new random action from the robot's action space.
        new_random_action = action_space_robot.sample()
        # Update the smoothed action using an exponential moving average.
        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action

        # Execute the step: wrap the NumPy action in a torch tensor.
        obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
        if terminated or truncated:
            env.reset()

        dt_s = time.perf_counter() - start_loop_s
        busy_wait(1 / args.fps - dt_s)