- Added `lerobot/scripts/server/gym_manipulator.py` that contains all the necessary wrappers to run a gym-style env around the real robot.

- Added `lerobot/scripts/server/find_joint_limits.py` to test the min and max angles of the motion you wish the robot to explore during RL training.
- Added logic in `manipulator.py` to limit the maximum possible joint angles to allow motion within a predefined joint position range. The limits are specified in the yaml config for each robot. Checkout the so100.yaml.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi 2025-02-06 16:29:37 +01:00 committed by AdilZouitine
parent 163bcbcad4
commit 729b4ed697
8 changed files with 812 additions and 29 deletions

View File

@ -126,27 +126,29 @@ class PixelWrapper(gym.Wrapper):
obs, reward, terminated, truncated, info = self.env.step(action) obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info return self._get_obs(obs), reward, terminated, truncated, info
class ConvertToLeRobotEnv(gym.Wrapper): class ConvertToLeRobotEnv(gym.Wrapper):
def __init__(self, env, num_envs): def __init__(self, env, num_envs):
super().__init__(env) super().__init__(env)
def reset(self, seed=None, options=None): def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options={}) obs, info = self.env.reset(seed=seed, options={})
return self._get_obs(obs), info return self._get_obs(obs), info
def step(self, action): def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action) obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info return self._get_obs(obs), reward, terminated, truncated, info
def _get_obs(self, observation): def _get_obs(self, observation):
sensor_data = observation.pop("sensor_data") sensor_data = observation.pop("sensor_data")
del observation["sensor_param"] del observation["sensor_param"]
images = [] images = []
for cam_data in sensor_data.values(): for cam_data in sensor_data.values():
images.append(cam_data["rgb"]) images.append(cam_data["rgb"])
images = torch.concat(images, axis=-1) images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data # flatten the rest of the data which should just be state data
observation = common.flatten_state_dict( observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
observation, use_torch=True, device=self.base_env.device
)
ret = dict() ret = dict()
ret["state"] = observation ret["state"] = observation
ret["pixels"] = images ret["pixels"] = images

View File

@ -36,11 +36,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO: You have to merge all tensors from agent key and extra key # TODO: You have to merge all tensors from agent key and extra key
# You don't keep sensor param key in the observation # You don't keep sensor param key in the observation
# And you keep sensor data rgb # And you keep sensor data rgb
if "pixels" in observations: for key, img in observations.items():
if isinstance(observations["pixels"], dict): if "images" not in key:
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} continue
else:
imgs = {"observation.image": observations["pixels"]}
for imgkey, img in imgs.items(): for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()? # TODO(aliberts, rcadene): use transforms.ToTensor()?
@ -50,15 +48,15 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
_, h, w, c = img.shape _, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8 # sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1] # convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous() img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32) img = img.type(torch.float32)
img /= 255 img /= 255
return_observations[imgkey] = img return_observations[key] = img
# obs state agent qpos and qvel # obs state agent qpos and qvel
# image # image
@ -69,7 +67,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
# requirement for "agent_pos" # requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() # return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return_observations["observation.state"] = observations["observation.state"].float()
return return_observations return return_observations

View File

@ -47,7 +47,7 @@ class Classifier(
super().__init__() super().__init__()
self.config = config self.config = config
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model # Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"): if hasattr(encoder, "vision_model"):
@ -108,11 +108,12 @@ class Classifier(
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder.""" """Extract the appropriate output from the encoder."""
# Process images with the processor (handles resizing and normalization) # Process images with the processor (handles resizing and normalization)
processed = self.processor( # processed = self.processor(
images=x, # LeRobotDataset already provides proper tensor format # images=x, # LeRobotDataset already provides proper tensor format
return_tensors="pt", # return_tensors="pt",
) # )
processed = processed["pixel_values"].to(x.device) # processed = processed["pixel_values"].to(x.device)
processed = x
with torch.no_grad(): with torch.no_grad():
if self.is_cnn: if self.is_cnn:
@ -146,6 +147,6 @@ class Classifier(
def predict_reward(self, x): def predict_reward(self, x):
if self.config.num_classes == 2: if self.config.num_classes == 2:
return (self.forward(x).probabilities > 0.5).float() return (self.forward(x).probabilities > 0.6).float()
else: else:
return torch.argmax(self.forward(x).probabilities, dim=1) return torch.argmax(self.forward(x).probabilities, dim=1)

View File

@ -45,7 +45,7 @@ def ensure_safe_goal_position(
safe_goal_pos = present_pos + safe_diff safe_goal_pos = present_pos + safe_diff
if not torch.allclose(goal_pos, safe_goal_pos): if not torch.allclose(goal_pos, safe_goal_pos):
logging.warning( logging.debug(
"Relative goal position magnitude had to be clamped to be safe.\n" "Relative goal position magnitude had to be clamped to be safe.\n"
f" requested relative goal position target: {diff}\n" f" requested relative goal position target: {diff}\n"
f" clamped relative goal position target: {safe_diff}" f" clamped relative goal position target: {safe_diff}"
@ -464,6 +464,14 @@ class ManipulatorRobot:
before_fwrite_t = time.perf_counter() before_fwrite_t = time.perf_counter()
goal_pos = leader_pos[name] goal_pos = leader_pos[name]
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# Cap goal position when too far away from present position. # Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower. # Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None: if self.config.max_relative_target is not None:
@ -585,6 +593,14 @@ class ManipulatorRobot:
goal_pos = action[from_idx:to_idx] goal_pos = action[from_idx:to_idx]
from_idx = to_idx from_idx = to_idx
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# Cap goal position when too far away from present position. # Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower. # Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None: if self.config.max_relative_target is not None:

View File

@ -4,7 +4,7 @@ defaults:
- _self_ - _self_
seed: 13 seed: 13
dataset_repo_id: aractingi/pick_place_lego_cube_1 dataset_repo_id: aractingi/push_green_cube_hf_cropped_resized
train_split_proportion: 0.8 train_split_proportion: 0.8
# Required by logger # Required by logger
@ -24,7 +24,8 @@ training:
eval_freq: 1 # How often to run validation (in epochs) eval_freq: 1 # How often to run validation (in epochs)
save_freq: 1 # How often to save checkpoints (in epochs) save_freq: 1 # How often to save checkpoints (in epochs)
save_checkpoint: true save_checkpoint: true
image_keys: ["observation.images.top", "observation.images.wrist"] # image_keys: ["observation.images.top", "observation.images.wrist"]
image_keys: ["observation.images.laptop", "observation.images.phone"]
label_key: "next.reward" label_key: "next.reward"
eval: eval:
@ -32,7 +33,7 @@ eval:
num_samples_to_log: 30 # Number of validation samples to log in the table num_samples_to_log: 30 # Number of validation samples to log in the table
policy: policy:
name: "hilserl/classifier/pick_place_lego_cube_1" name: "hilserl/classifier/push_green_cube_hf_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_1"
model_name: "facebook/convnext-base-224" model_name: "facebook/convnext-base-224"
model_type: "cnn" model_type: "cnn"
num_cameras: 2 # Has to be len(training.image_keys) num_cameras: 2 # Has to be len(training.image_keys)

View File

@ -14,6 +14,9 @@ calibration_dir: .cache/calibration/so100
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms. # the number of motors in your follower arms.
max_relative_target: null max_relative_target: null
joint_position_relative_bounds:
min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
leader_arms: leader_arms:
main: main:

View File

@ -0,0 +1,64 @@
import argparse
import time
import cv2
import numpy as np
from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
def find_joint_bounds(
robot,
control_time_s=20,
display_cameras=False,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
robot.connect()
control_time_s = float("inf")
timestamp = 0
start_episode_t = time.perf_counter()
pos_list = []
while timestamp < control_time_s:
observation, action = robot.teleop_step(record_data=True)
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
timestamp = time.perf_counter() - start_episode_t
if timestamp > 60:
max = np.max(np.stack(pos_list), 0)
min = np.min(np.stack(pos_list), 0)
print(f"Max angle position per joint {max}")
print(f"Min angle position per joint {min}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
find_joint_bounds(robot, control_time_s=args.control_time_s)

View File

@ -0,0 +1,697 @@
import argparse
import logging
import time
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple
import cv2
import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
logging.basicConfig(level=logging.INFO)
class HILSerlRobotEnv(gym.Env):
"""
Gym-like environment wrapper for robot policy evaluation.
This wrapper provides a consistent interface for interacting with the robot,
following the OpenAI Gym environment conventions.
"""
def __init__(
self,
robot,
use_delta_action_space: bool = True,
delta: float | None = None,
display_cameras=False,
):
"""
Initialize the robot environment.
Args:
robot: The robot interface object
reward_classifier: Optional reward classifier
fps: Frames per second for control
control_time_s: Total control time for each episode
display_cameras: Whether to display camera feeds
output_normalization_params_action: Bound parameters for the action space
delta: The delta for the relative joint position action space
"""
super().__init__()
self.robot = robot
self.display_cameras = display_cameras
# connect robot
if not self.robot.is_connected:
self.robot.connect()
self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
# Episode tracking
self.current_step = 0
self.episode_data = None
self.delta = delta
self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
self.relative_bounds_size = (
self.robot.config.joint_position_relative_bounds["max"]
- self.robot.config.joint_position_relative_bounds["min"]
)
self.delta_relative_bounds_size = self.relative_bounds_size * self.delta
self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()
# Dynamically determine observation and action spaces
self._setup_spaces()
def _setup_spaces(self):
"""
Dynamically determine observation and action spaces based on robot capabilities.
This method should be customized based on the specific robot's observation
and action representations.
"""
# Example space setup - you'll need to adapt this to your specific robot
example_obs = self.robot.capture_observation()
# Observation space (assuming image-based observations)
image_keys = [key for key in example_obs if "image" in key]
state_keys = [key for key in example_obs if "image" not in key]
observation_spaces = {
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
for key in image_keys
}
observation_spaces["observation.state"] = gym.spaces.Dict(
{
key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
for key in state_keys
}
)
self.observation_space = gym.spaces.Dict(observation_spaces)
# Action space (assuming joint positions)
action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
if self.use_delta_action_space:
action_space_robot = gym.spaces.Box(
low=-self.relative_bounds_size.cpu().numpy(),
high=self.relative_bounds_size.cpu().numpy(),
shape=(action_dim,),
dtype=np.float32,
)
else:
action_space_robot = gym.spaces.Box(
low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(),
high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(),
shape=(action_dim,),
dtype=np.float32,
)
self.action_space = gym.spaces.Tuple(
(
action_space_robot,
gym.spaces.Discrete(2),
),
)
def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""
Reset the environment to initial state.
Returns:
observation (dict): Initial observation
info (dict): Additional information
"""
super().reset(seed=seed, options=options)
# Capture initial observation
observation = self.robot.capture_observation()
# Reset tracking variables
self.current_step = 0
self.episode_data = None
return observation, {"initial_position": self.initial_follower_position}
def step(
self, action: Tuple[np.ndarray, bool]
) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
"""
Take a step in the environment.
Args:
action tuple(np.ndarray, bool):
Policy action to be executed on the robot and boolean to determine
whether to choose policy action or expert action.
Returns:
observation (dict): Next observation
reward (float): Reward for this step
terminated (bool): Whether the episode has terminated
truncated (bool): Whether the episode was truncated
info (dict): Additional information
"""
# The actions recieved are the in form of a tuple containing the policy action and an intervention bool
# The boolean inidicated whether we will use the expert's actions (through teleoperation) or the policy actions
policy_action, intervention_bool = action
teleop_action = None
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
if isinstance(policy_action, torch.Tensor):
policy_action = policy_action.cpu().numpy()
olicy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
if not intervention_bool:
if self.use_delta_action_space:
target_joint_positions = self.current_joint_positions + self.delta * policy_action
else:
target_joint_positions = policy_action
self.robot.send_action(torch.from_numpy(target_joint_positions))
observation = self.robot.capture_observation()
else:
observation, teleop_action = self.robot.teleop_step(record_data=True)
teleop_action = teleop_action["action"] # teleop step returns torch tensors but in a dict
# teleop actions are returned in absolute joint space
# If we are using a relative joint position action space,
# there will be a mismatch between the spaces of the policy and teleop actions
# Solution is to transform the teleop actions into relative space.
# teleop relative action is:
if self.use_delta_action_space:
teleop_action = teleop_action - self.current_joint_positions
if torch.any(teleop_action < -self.delta_relative_bounds_size * self.delta) and torch.any(
teleop_action > self.delta_relative_bounds_size
):
print(
f"relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
f"lower bounds condition {teleop_action < -self.delta_relative_bounds_size}\n"
f"upper bounds condition {teleop_action > self.delta_relative_bounds_size}"
)
teleop_action = torch.clamp(
teleop_action, -self.delta_relative_bounds_size, self.delta_relative_bounds_size
)
self.current_step += 1
reward = 0.0
terminated = False
truncated = False
return (
observation,
reward,
terminated,
truncated,
{"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
)
def render(self):
"""
Render the environment (in this case, display camera feeds).
"""
import cv2
observation = self.robot.capture_observation()
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
def close(self):
"""
Close the environment and disconnect the robot.
"""
if self.robot.is_connected:
self.robot.disconnect()
class ActionRepeatWrapper(gym.Wrapper):
def __init__(self, env, nb_repeat: int = 1):
super().__init__(env)
self.nb_repeat = nb_repeat
def step(self, action):
for _ in range(self.nb_repeat):
obs, reward, done, truncated, info = self.env.step(action)
if done or truncated:
break
return obs, reward, done, truncated, info
class RelativeJointPositionActionWrapper(gym.Wrapper):
def __init__(
self,
env: HILSerlRobotEnv,
# output_normalization_params_action: dict[str, list[float]],
delta: float = 0.1,
):
super().__init__(env)
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
self.delta = delta
if delta > 1:
raise ValueError("Delta should be less than 1")
def step(self, action):
action_joint = action
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
if isinstance(self.env.action_space, gym.spaces.Tuple):
action_joint = action[0]
joint_positions = self.joint_positions + (self.delta * action_joint)
# clip the joint positions to the joint limits with the action space
joint_positions = np.clip(joint_positions, self.action_space.low, self.action_space.high)
if isinstance(self.env.action_space, gym.spaces.Tuple):
return self.env.step((joint_positions, action[1]))
obs, reward, terminated, truncated, info = self.env.step(joint_positions)
if info["is_intervention"]:
# teleop actions are returned in absolute joint space
# If we are using a relative joint position action space,
# there will be a mismatch between the spaces of the policy and teleop actions
# Solution is to transform the teleop actions into relative space.
self.joint_positions = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
teleop_action = info["action_intervention"] # teleop actions are in absolute joint space
relative_teleop_action = (teleop_action - self.joint_positions) / self.delta
info["action_intervention"] = relative_teleop_action
return self.env.step(joint_positions)
class RewardWrapper(gym.Wrapper):
def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
self.env = env
self.reward_classifier = torch.compile(reward_classifier)
self.device = device
def step(self, action):
observation, _, terminated, truncated, info = self.env.step(action)
images = [
observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key
]
start_time = time.perf_counter()
with torch.inference_mode():
reward = (
self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
)
# print(f"fps for reward classifier {1/(time.perf_counter() - start_time)}")
reward = reward.item()
# print(f"Reward from reward classifier {reward}")
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
return self.env.reset(seed=seed, options=options)
class TimeLimitWrapper(gym.Wrapper):
def __init__(self, env, control_time_s, fps):
self.env = env
self.control_time_s = control_time_s
self.fps = fps
self.last_timestamp = 0.0
self.episode_time_in_s = 0.0
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
time_since_last_step = time.perf_counter() - self.last_timestamp
self.episode_time_in_s += time_since_last_step
self.last_timestamp = time.perf_counter()
# check if last timestep took more time than the expected fps
if 1.0 / time_since_last_step < self.fps:
logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
if self.episode_time_in_s > self.control_time_s:
# Terminated = True
terminated = True
return obs, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
self.episode_time_in_s = 0.0
self.last_timestamp = time.perf_counter()
return self.env.reset(seed=seed, options=options)
class ImageCropResizeWrapper(gym.Wrapper):
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
self.env = env
self.crop_params_dict = crop_params_dict
print(f"obs_keys , {self.env.observation_space}")
print(f"crop params dict {crop_params_dict.keys()}")
for key_crop in crop_params_dict:
if key_crop not in self.env.observation_space.keys():
raise ValueError(f"Key {key_crop} not in observation space")
for key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width)
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
self.resize_size = resize_size
if self.resize_size is None:
self.resize_size = (128, 128)
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
for k in self.crop_params_dict:
device = obs[k].device
if device == torch.device("mps:0"):
obs[k] = obs[k].cpu()
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
obs[k] = F.resize(obs[k], self.resize_size)
obs[k] = obs[k].to(device)
# print(f"observation with key {k} with size {obs[k].size()}")
cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
return obs, reward, terminated, truncated, info
class ConvertToLeRobotObservation(gym.ObservationWrapper):
def __init__(self, env, device):
super().__init__(env)
self.device = device
def observation(self, observation):
observation = preprocess_observation(observation)
observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation}
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
return observation
class KeyboardInterfaceWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.listener = None
self.events = {
"exit_early": False,
"pause_policy": False,
"reset_env": False,
"human_intervention_step": False,
}
self.event_lock = Lock() # Thread-safe access to events
self._init_keyboard_listener()
def _init_keyboard_listener(self):
"""Initialize keyboard listener if not in headless mode"""
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
return
try:
from pynput import keyboard
def on_press(key):
with self.event_lock:
try:
if key == keyboard.Key.right or key == keyboard.Key.esc:
print("Right arrow key pressed. Exiting loop...")
self.events["exit_early"] = True
elif key == keyboard.Key.space:
if not self.events["pause_policy"]:
print(
"Space key pressed. Human intervention required.\n"
"Place the leader in similar pose to the follower and press space again."
)
self.events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
self.events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say("Starting human intervention.", play_sounds=True)
else:
self.events["pause_policy"] = False
self.events["human_intervention_step"] = False
print("Space key pressed for a third time.")
log_say("Continuing with policy actions.", play_sounds=True)
except Exception as e:
print(f"Error handling key press: {e}")
self.listener = keyboard.Listener(on_press=on_press)
self.listener.start()
except ImportError:
logging.warning("Could not import pynput. Keyboard interface will not be available.")
self.listener = None
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
is_intervention = False
terminated_by_keyboard = False
# Extract policy_action if needed
if isinstance(self.env.action_space, gym.spaces.Tuple):
policy_action = action[0]
# Check the event flags without holding the lock for too long.
with self.event_lock:
if self.events["exit_early"]:
terminated_by_keyboard = True
# If we need to wait for human intervention, we note that outside the lock.
pause_policy = self.events["pause_policy"]
if pause_policy:
# Now, wait for human_intervention_step without holding the lock
while True:
with self.event_lock:
if self.events["human_intervention_step"]:
is_intervention = True
break
time.sleep(0.1) # Check more frequently if desired
# Execute the step in the underlying environment
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
return obs, reward, terminated or terminated_by_keyboard, truncated, info
def reset(self, **kwargs) -> Tuple[Any, Dict]:
"""
Reset the environment and clear any pending events
"""
with self.event_lock:
self.events = {k: False for k in self.events}
return self.env.reset(**kwargs)
def close(self):
"""
Properly clean up the keyboard listener when the environment is closed
"""
if self.listener is not None:
self.listener.stop()
super().close()
class ResetWrapper(gym.Wrapper):
def __init__(
self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
):
super().__init__(env)
self.reset_fn = reset_fn
self.reset_time_s = reset_time_s
self.robot = self.unwrapped.robot
self.init_pos = self.unwrapped.initial_follower_position
def reset(self, *, seed=None, options=None):
if self.reset_fn is not None:
self.reset_fn(self.env)
else:
log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
start_time = time.perf_counter()
while time.perf_counter() - start_time < self.reset_time_s:
self.robot.teleop_step()
log_say("Manual reseting of the environment done.", play_sounds=True)
return super().reset(seed=seed, options=options)
def make_robot_env(
robot,
reward_classifier,
crop_params_dict=None,
fps=30,
control_time_s=20,
reset_follower_pos=True,
display_cameras=False,
device="cuda:0",
resize_size=None,
reset_time_s=10,
delta_action=0.1,
nb_repeats=1,
use_relative_joint_positions=False,
):
"""
Factory function to create the robot environment.
Mimics gym.make() for consistent environment creation.
"""
env = HILSerlRobotEnv(
robot,
display_cameras=display_cameras,
delta=delta_action,
use_delta_action_space=use_relative_joint_positions,
)
env = ConvertToLeRobotObservation(env, device)
if crop_params_dict is not None:
env = ImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
env = RewardWrapper(env, reward_classifier, device=device)
env = TimeLimitWrapper(env, control_time_s, fps)
# env = ActionRepeatWrapper(env, nb_repeat=nb_repeats)
env = KeyboardInterfaceWrapper(env)
env = ResetWrapper(env, reset_fn=None, reset_time_s=reset_time_s)
return env
def get_classifier(pretrained_path, config_path, device="mps"):
if pretrained_path is None or config_path is None:
return
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fps", type=int, default=30, help="control frequency")
parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
parser.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
)
parser.add_argument(
"--reward-classifier-pretrained-path",
type=str,
default=None,
help="Path to the pretrained classifier weights.",
)
parser.add_argument(
"--reward-classifier-config-file",
type=str,
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
reward_classifier = get_classifier(
args.reward_classifier_pretrained_path, args.reward_classifier_config_file
)
crop_parameters = {
"observation.images.laptop": (58, 89, 357, 455),
"observation.images.phone": (3, 4, 471, 633),
}
user_relative_joint_positions = True
env = make_robot_env(
robot,
reward_classifier,
crop_parameters,
args.fps,
args.control_time_s,
args.reset_follower_pos,
args.display_cameras,
device="mps",
resize_size=None,
reset_time_s=10,
delta_action=0.1,
nb_repeats=1,
use_relative_joint_positions=user_relative_joint_positions,
)
env.reset()
init_pos = env.unwrapped.initial_follower_position
right_goal = init_pos.copy()
right_goal[0] += 50
left_goal = init_pos.copy()
left_goal[0] -= 50
pitch_angle = np.linspace(left_goal[0], right_goal[0], 1000)
delta_angle = np.concatenate((-np.ones(50), np.ones(50))) * 100
while True:
action = np.zeros(len(init_pos))
for i in range(len(delta_angle)):
start_loop_s = time.perf_counter()
action[0] = delta_angle[i]
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
if terminated or truncated:
env.reset()
dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / args.fps - dt_s)
# action = np.zeros(len(init_pos)) if user_relative_joint_positions else init_pos
# for i in range(len(pitch_angle)):
# if user_relative_joint_positions:
# action[0] = delta_angle[i]
# else:
# action[0] = pitch_angle[i]
# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
# if terminated or truncated:
# logging.info("Max control time reached, reset environment.")
# env.reset()
# for i in reversed(range(len(pitch_angle))):
# if user_relative_joint_positions:
# action[0] = delta_angle[i]
# else:
# action[0] = pitch_angle[i]
# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(action), False))
# if terminated or truncated:
# logging.info("Max control time reached, reset environment.")
# env.reset()