- Added base gym env class for the real robot environment.

- Added several wrappers around the base gym env robot class.
- Including: time limit, reward classifier, crop images, preprocess observations.
- Added an interactive script crop_roi.py where the user can interactively select the roi in the observation images and return the correct crop values that will improve the policy and reward classifier performance.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi 2025-02-03 14:52:45 +00:00
parent 506821c7df
commit 2211209be5
2 changed files with 528 additions and 0 deletions

View File

@ -0,0 +1,148 @@
import cv2
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
def select_square_roi(img):
"""
Allows the user to draw a square ROI on the image.
The user must click and drag to draw the square.
- While dragging, the square is dynamically drawn.
- On mouse button release, the square is fixed.
- Press 'c' to confirm the selection.
- Press 'r' to reset the selection.
- Press ESC to cancel.
Returns:
A tuple (top, left, height, width) representing the square ROI,
or None if no valid ROI is selected.
"""
# Create a working copy of the image
clone = img.copy()
working_img = clone.copy()
roi = None # Will store the final ROI as (top, left, side, side)
drawing = False
ix, iy = -1, -1 # Initial click coordinates
def mouse_callback(event, x, y, flags, param):
nonlocal ix, iy, drawing, roi, working_img
if event == cv2.EVENT_LBUTTONDOWN:
# Start drawing: record starting coordinates
drawing = True
ix, iy = x, y
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
# Compute side length as the minimum of horizontal/vertical drags
side = min(abs(x - ix), abs(y - iy))
# Determine the direction to draw (in case of dragging to top/left)
dx = side if x >= ix else -side
dy = side if y >= iy else -side
# Show a temporary image with the current square drawn
temp = working_img.copy()
cv2.rectangle(temp, (ix, iy), (ix + dx, iy + dy), (0, 255, 0), 2)
cv2.imshow("Select ROI", temp)
elif event == cv2.EVENT_LBUTTONUP:
# Finish drawing
drawing = False
side = min(abs(x - ix), abs(y - iy))
dx = side if x >= ix else -side
dy = side if y >= iy else -side
# Normalize coordinates: (top, left) is the minimum of the two points
x1 = min(ix, ix + dx)
y1 = min(iy, iy + dy)
roi = (y1, x1, side, side) # (top, left, height, width)
# Draw the final square on the working image and display it
working_img = clone.copy()
cv2.rectangle(working_img, (ix, iy), (ix + dx, iy + dy), (0, 255, 0), 2)
cv2.imshow("Select ROI", working_img)
# Create the window and set the callback
cv2.namedWindow("Select ROI")
cv2.setMouseCallback("Select ROI", mouse_callback)
cv2.imshow("Select ROI", working_img)
print("Instructions for ROI selection:")
print(" - Click and drag to draw a square ROI.")
print(" - Press 'c' to confirm the selection.")
print(" - Press 'r' to reset and draw again.")
print(" - Press ESC to cancel the selection.")
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
while True:
key = cv2.waitKey(1) & 0xFF
# Confirm ROI if one has been drawn
if key == ord("c") and roi is not None:
break
# Reset: clear the ROI and restore the original image
elif key == ord("r"):
working_img = clone.copy()
roi = None
cv2.imshow("Select ROI", working_img)
# Cancel selection for this image
elif key == 27: # ESC key
roi = None
break
cv2.destroyWindow("Select ROI")
return roi
def select_square_roi_for_images(images: dict) -> dict:
"""
For each image in the provided dictionary, open a window to allow the user
to select a square ROI. Returns a dictionary mapping each key to a tuple
(top, left, height, width) representing the ROI.
Parameters:
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
Returns:
dict: Mapping of image keys to the selected square ROI.
"""
selected_rois = {}
for key, img in images.items():
if img is None:
print(f"Image for key '{key}' is None, skipping.")
continue
print(f"\nSelect square ROI for image with key: '{key}'")
roi = select_square_roi(img)
if roi is None:
print(f"No valid ROI selected for '{key}'.")
else:
selected_rois[key] = roi
print(f"ROI for '{key}': {roi}")
return selected_rois
if __name__ == "__main__":
# Example usage:
# Replace 'image1.jpg' and 'image2.jpg' with valid paths to your image files.
fps = [5, 30]
cameras = [OpenCVCamera(i, fps=fps[i], width=640, height=480, mock=False) for i in range(2)]
[camera.connect() for camera in cameras]
image_keys = ["image_" + str(i) for i in range(len(cameras))]
images = {image_keys[i]: cameras[i].read() for i in range(len(cameras))}
# Verify images loaded correctly
for key, img in images.items():
if img is None:
raise ValueError(f"Failed to load image for key '{key}'. Check the file path.")
# Let the user select a square ROI for each image
rois = select_square_roi_for_images(images)
# Print the selected square ROIs
print("\nSelected Square Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")

View File

@ -0,0 +1,380 @@
import argparse
import logging
import time
from typing import Annotated, Any, Dict, Optional, Tuple
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F # noqa: N812
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import reset_follower_position
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
logging.basicConfig(level=logging.INFO)
class HILSerlRobotEnv(gym.Env):
"""
Gym-like environment wrapper for robot policy evaluation.
This wrapper provides a consistent interface for interacting with the robot,
following the OpenAI Gym environment conventions.
"""
def __init__(
self,
robot,
reset_follower_position=True,
display_cameras=False,
):
"""
Initialize the robot environment.
Args:
robot: The robot interface object
reward_classifier: Optional reward classifier
fps: Frames per second for control
control_time_s: Total control time for each episode
display_cameras: Whether to display camera feeds
"""
super().__init__()
self.robot = robot
self.display_cameras = display_cameras
# connect robot
if not self.robot.is_connected:
self.robot.connect()
# Dynamically determine observation and action spaces
self._setup_spaces()
self._initial_follower_position = robot.follower_arms["main"].read("Present_Position")
self.reset_follower_position = reset_follower_position
# Episode tracking
self.current_step = 0
self.episode_data = None
def _setup_spaces(self):
"""
Dynamically determine observation and action spaces based on robot capabilities.
This method should be customized based on the specific robot's observation
and action representations.
"""
# Example space setup - you'll need to adapt this to your specific robot
example_obs = self.robot.capture_observation()
# Observation space (assuming image-based observations)
image_keys = [key for key in example_obs if "image" in key]
state_keys = [key for key in example_obs if "image" not in key]
observation_spaces = {
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
for key in image_keys
}
observation_spaces["observation.state"] = gym.spaces.Dict(
{
key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
for key in state_keys
}
)
self.observation_space = gym.spaces.Dict(observation_spaces)
# Action space (assuming joint positions)
action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
self.action_space = gym.spaces.Tuple(
(
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(action_dim,), dtype=np.float32),
gym.spaces.Discrete(2),
),
)
def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""
Reset the environment to initial state.
Returns:
observation (dict): Initial observation
info (dict): Additional information
"""
super().reset(seed=seed, options=options)
if self.reset_follower_position:
reset_follower_position(self.robot, target_position=self._initial_follower_position)
# Capture initial observation
observation = self.robot.capture_observation()
# Reset tracking variables
self.current_step = 0
self.episode_data = None
return observation, {}
def step(
self, action: Tuple[np.ndarray, bool]
) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
"""
Take a step in the environment.
Args:
action tuple(np.ndarray, bool):
Policy action to be executed on the robot and boolean to determine
whether to choose policy action or expert action.
Returns:
observation (dict): Next observation
reward (float): Reward for this step
terminated (bool): Whether the episode has terminated
truncated (bool): Whether the episode was truncated
info (dict): Additional information
"""
# The actions recieved are the in form of a tuple containing the policy action and an intervention bool
# The boolean inidicated whether we will use the expert's actions (through teleoperation) or the policy actions
policy_action, intervention_bool = action
teleop_action = None
if not intervention_bool:
self.robot.send_action(policy_action.cpu().numpy())
observation = self.robot.capture_observation()
else:
observation, teleop_action = self.robot.teleop_step(record_data=True)
teleop_action = teleop_action["action"] # teleop step returns torch tensors but in a dict
self.current_step += 1
reward = 0.0
terminated = False
truncated = False
return observation, reward, terminated, truncated, {"action": teleop_action}
def render(self):
"""
Render the environment (in this case, display camera feeds).
"""
import cv2
observation = self.robot.capture_observation()
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
def close(self):
"""
Close the environment and disconnect the robot.
"""
if self.robot.is_connected:
self.robot.disconnect()
class HILSerlTimeLimitWrapper(gym.Wrapper):
def __init__(self, env, control_time_s, fps):
self.env = env
self.control_time_s = control_time_s
self.fps = fps
self.last_timestamp = 0.0
self.episode_time_in_s = 0.0
def step(self, action):
ret = self.env.step(action)
time_since_last_step = time.perf_counter() - self.last_timestamp
self.episode_time_in_s += time_since_last_step
self.last_timestamp = time.perf_counter()
# check if last timestep took more time than the expected fps
if 1.0 / time_since_last_step > self.fps:
logging.warning(f"Current timestep exceeded expected fps {self.fps}")
if self.episode_time_in_s > self.control_time_s:
# Terminated = True
ret[2] = True
return ret
def reset(self, seed=None, options=None):
self.episode_time_in_s = 0.0
self.last_timestamp = time.perf_counter()
return self.env.reset(seed, options=None)
class HILSerlRewardWrapper(gym.Wrapper):
def __init__(self, env, reward_classifier: Optional[None], device: torch.device = "cuda"):
self.env = env
self.reward_classifier = reward_classifier
self.device = device
def step(self, action):
observation, _, terminated, truncated, info = self.env.step(action)
images = [
observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key
]
reward = self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
reward = reward.item()
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
return self.env.reset(seed=seed, options=options)
class HILSerlImageCropResizeWrapper(gym.Wrapper):
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
self.env = env
self.crop_params_dict = crop_params_dict
for key in crop_params_dict:
assert key in self.env.observation_space, f"Key {key} not in observation space"
top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width)
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
self.resize_size = resize_size
if self.resize_size is None:
self.resize_size = (128, 128)
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
for k in self.crop_params_dict:
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
obs[k] = F.resize(obs[k], self.resize_size)
return obs, reward, terminated, truncated, info
class ConvertToLeRobotObservation(gym.ObservationWrapper):
def __init__(self, env, device):
super().__init__(env)
self.device = device
def observation(self, observation):
observation = preprocess_observation(observation)
observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation}
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
return observation
def make_robot_env(
robot,
reward_classifier,
crop_params_dict=None,
fps=30,
control_time_s=20,
reset_follower_pos=True,
display_cameras=False,
device="cuda:0",
resize_size=None,
):
"""
Factory function to create the robot environment.
Mimics gym.make() for consistent environment creation.
"""
env = HILSerlRobotEnv(robot, reset_follower_pos, display_cameras)
env = ConvertToLeRobotObservation(env, device)
if crop_params_dict is not None:
env = HILSerlImageCropResizeWrapper(env, crop_params_dict, resize_size=resize_size)
env = HILSerlRewardWrapper(env, reward_classifier)
env = HILSerlTimeLimitWrapper(env, control_time_s, fps)
return env
def get_classifier(pretrained_path, config_path, device="mps"):
if pretrained_path is None or config_path is None:
return
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fps", type=int, default=30, help="control frequency")
parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
parser.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
)
parser.add_argument(
"--reward-classifier-pretrained-path",
type=str,
default=None,
help="Path to the pretrained classifier weights.",
)
parser.add_argument(
"--reward-classifier-config-file",
type=str,
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
reward_classifier = get_classifier(
args.reward_classifier_pretrained_path, args.reward_classifier_config_file
)
env = make_robot_env(
robot,
reward_classifier,
None,
args.fps,
args.control_time_s,
args.reset_follower_pos,
args.display_cameras,
device="mps",
)
env.reset()
while True:
intervention_action = (None, True)
obs, reward, terminated, truncated, info = env.step(intervention_action)
if terminated or truncated:
logging.info("Max control time reached, reset environment.")
env.reset()