diff --git a/lerobot/common/robot_devices/cameras/utils.py b/lerobot/common/robot_devices/cameras/utils.py
index 08c9465f..0f329d9f 100644
--- a/lerobot/common/robot_devices/cameras/utils.py
+++ b/lerobot/common/robot_devices/cameras/utils.py
@@ -2,6 +2,7 @@ from pathlib import Path
 from typing import Protocol
 
 import cv2
+import einops
 import numpy as np
 
 
@@ -39,6 +40,16 @@ def save_depth_image(depth, path, write_shape=False):
     cv2.imwrite(str(path), depth_image)
 
 
+def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
+    assert tensor.ndim == 3
+    c, h, w = tensor.shape
+    assert c < h and c < w
+    color_image = einops.rearrange(tensor, "c h w -> h w c").numpy()
+    if rgb_to_bgr:
+        color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
+    return color_image
+
+
 # Defines a camera type
 class Camera(Protocol):
     def connect(self): ...
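For context, here is a minimal usage sketch of the new helper. The shape and dtype are assumptions (LeRobot camera observations are channel-first RGB `torch.uint8` tensors on CPU by the time they reach the control loops); `convert_torch_image_to_cv2` itself comes from the hunk above:

```python
import torch

from lerobot.common.robot_devices.cameras.utils import convert_torch_image_to_cv2

# A dummy 480x640 RGB frame in the CHW layout used by LeRobot observations.
# Tensor.numpy() requires a CPU tensor, so call .cpu() first for CUDA tensors.
chw_rgb = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)

hwc_bgr = convert_torch_image_to_cv2(chw_rgb)  # HWC, BGR: ready for cv2.imshow
hwc_rgb = convert_torch_image_to_cv2(chw_rgb, rgb_to_bgr=False)  # HWC, RGB preserved
assert hwc_bgr.shape == (480, 640, 3)
```

The `c < h and c < w` assertion is a cheap sanity check that the input really is channel-first before the rearrange.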
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 1ee11005..71903568 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -88,8 +88,10 @@ import platform
 import shutil
 import time
 from contextlib import nullcontext
+from functools import cache
 from pathlib import Path
 
+import cv2
 import torch
 import tqdm
 from huggingface_hub import create_branch
@@ -105,6 +107,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import calculate_episode_data_index
 from lerobot.common.datasets.video_utils import encode_video_frames
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.robot_devices.cameras.utils import convert_torch_image_to_cv2
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
@@ -179,7 +182,8 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
     logging.info(info_str)
 
 
-def get_is_headless():
+@cache
+def is_headless():
     if platform.system() == "Linux":
         display = os.environ.get("DISPLAY")
         if display is None or display == "":
@@ -255,7 +259,10 @@ def record_dataset(
     else:
         episode_index = 0
 
-    is_headless = get_is_headless()
+    if is_headless():
+        logging.info(
+            "Headless environment detected. On-screen camera display and keyboard input will not be available."
+        )
 
     # Execute a few seconds without recording data, to give time
     # to the robot devices to connect and start synchronizing.
@@ -269,10 +276,14 @@
             is_warmup_print = True
 
         now = time.perf_counter()
+
         observation, action = robot.teleop_step(record_data=True)
 
-        if not is_headless:
+        if not is_headless():
             image_keys = [key for key in observation if "image" in key]
+            for key in image_keys:
+                cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
+            cv2.waitKey(1)
 
         dt_s = time.perf_counter() - now
         busy_wait(1 / fps - dt_s)
@@ -290,9 +301,7 @@ def record_dataset(
     stop_recording = False
 
     # Only import pynput if not in a headless environment
-    if is_headless:
-        logging.info("Headless environment detected. Keyboard input will not be available.")
-    else:
+    if not is_headless():
        from pynput import keyboard

        def on_press(key):
@@ -342,6 +351,12 @@ def record_dataset(
             )
         ]
 
+        if not is_headless():
+            image_keys = [key for key in observation if "image" in key]
+            for key in image_keys:
+                cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
+            cv2.waitKey(1)
+
         for key in not_image_keys:
             if key not in ep_dict:
                 ep_dict[key] = []
@@ -434,7 +449,7 @@ def record_dataset(
         if is_last_episode:
             logging.info("Done recording")
             os.system('say "Done recording"')
-            if not is_headless:
+            if not is_headless():
                 listener.stop()
 
             logging.info("Waiting for threads writing the images on disk to terminate...")
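The same five-line display block now appears four times across the patch (teleop warmup, the recording loop, and twice more in `run_policy` below). Two notes: `is_headless()` is wrapped in `functools.cache`, so calling it inside these hot loops costs a memoized lookup rather than a fresh platform check; and if the display block were ever factored out, it could look like this sketch (the helper name is hypothetical; `observation` is the dict the robot API returns in the patch):

```python
import cv2

from lerobot.common.robot_devices.cameras.utils import convert_torch_image_to_cv2


def show_camera_frames(observation: dict) -> None:
    # One OpenCV window per camera key; waitKey(1) pumps the GUI event loop
    # so the windows refresh without blocking the control loop.
    image_keys = [key for key in observation if "image" in key]
    for key in image_keys:
        cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
    cv2.waitKey(1)
```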
@@ -543,7 +558,14 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
         log_control_info(robot, dt_s, fps=fps)
 
 
-def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
+def run_policy(
+    robot: Robot,
+    policy: torch.nn.Module,
+    hydra_cfg: DictConfig,
+    warmup_time_s: float = 4,
+    run_time_s: float | None = None,
+    reset_time_s: float = 15,
+):
     # TODO(rcadene): Add option to record eval dataset and logs
 
     # Check device is available
@@ -561,12 +583,76 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
     if not robot.is_connected:
         robot.connect()
 
+    if is_headless():
+        logging.info(
+            "Headless environment detected. On-screen camera display and keyboard input will not be available."
+        )
+
+    # Execute a few seconds without recording data, to give time
+    # to the robot devices to connect and start synchronizing.
+    timestamp = 0
+    start_time = time.perf_counter()
+    is_warmup_print = False
+    while timestamp < warmup_time_s:
+        if not is_warmup_print:
+            logging.info("Warming up (no data recording)")
+            os.system('say "Warmup" &')
+            is_warmup_print = True
+
+        now = time.perf_counter()
+        observation = robot.capture_observation()
+
+        if not is_headless():
+            image_keys = [key for key in observation if "image" in key]
+            for key in image_keys:
+                cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
+            cv2.waitKey(1)
+
+        dt_s = time.perf_counter() - now
+        busy_wait(1 / fps - dt_s)
+
+        dt_s = time.perf_counter() - now
+        log_control_info(robot, dt_s, fps=fps)
+
+        timestamp = time.perf_counter() - start_time
+
+    # Allow resetting the environment or exiting early
+    # by tapping the right arrow key '->'. This might require sudo permission
+    # to allow your terminal to monitor keyboard events.
+    reset_environment = False
+    exit_reset = False
+
+    # Only import pynput if not in a headless environment
+    if not is_headless():
+        from pynput import keyboard
+
+        def on_press(key):
+            nonlocal reset_environment, exit_reset
+            try:
+                if key == keyboard.Key.right and not reset_environment:
+                    print("Right arrow key pressed. Suspend robot control to reset environment...")
+                    reset_environment = True
+                elif key == keyboard.Key.right and reset_environment:
+                    print("Right arrow key pressed. Enable robot control and exit reset environment...")
+                    exit_reset = True
+            except Exception as e:
+                print(f"Error handling key press: {e}")
+
+        listener = keyboard.Listener(on_press=on_press)
+        listener.start()
+
     start_time = time.perf_counter()
     while True:
         now = time.perf_counter()
 
         observation = robot.capture_observation()
 
+        if not is_headless():
+            image_keys = [key for key in observation if "image" in key]
+            for key in image_keys:
+                cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
+            cv2.waitKey(1)
+
         with (
             torch.inference_mode(),
             torch.autocast(device_type=device.type)
@@ -597,6 +683,25 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
         if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
             break
 
+        if reset_environment:
+            # Start resetting the env while the executor is finishing
+            logging.info("Reset the environment")
+            os.system('say "Reset the environment" &')
+
+            # Wait if necessary
+            timestamp = 0
+            start_time = time.perf_counter()
+            with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
+                while timestamp < reset_time_s:
+                    time.sleep(1)
+                    timestamp = time.perf_counter() - start_time
+                    pbar.update(1)
+                    if exit_reset:
+                        exit_reset = False
+                        break
+
+            reset_environment = False
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
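A design note on the reset flow added at the end of `run_policy`: the wait loop ticks in whole seconds, so the tqdm bar (with `total=reset_time_s`) advances once per `time.sleep(1)`, and the `exit_reset` flag set by the keyboard listener is polled between ticks. In isolation, and with illustrative names, the pattern is:

```python
import time

import tqdm


def wait_with_early_exit(duration_s: float, should_exit) -> None:
    # Sleep in 1 s ticks so a flag flipped by a keyboard listener can cut
    # the wait short, mirroring the reset loop added to run_policy.
    start = time.perf_counter()
    with tqdm.tqdm(total=duration_s, desc="Waiting") as pbar:
        while time.perf_counter() - start < duration_s:
            time.sleep(1)
            pbar.update(1)
            if should_exit():
                break
```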