diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index b4b66b3c..fe335556 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -84,6 +84,7 @@ import concurrent.futures import json import logging import os +import platform import shutil import time from contextlib import nullcontext @@ -94,7 +95,6 @@ import tqdm from huggingface_hub import create_branch from omegaconf import DictConfig from PIL import Image -from pynput import keyboard # from safetensors.torch import load_file, save_file from lerobot.common.datasets.compute_stats import compute_stats @@ -270,25 +270,38 @@ def record_dataset( rerecord_episode = False stop_recording = False - def on_press(key): - nonlocal exit_early, rerecord_episode, stop_recording - try: - if key == keyboard.Key.right: - print("Right arrow key pressed. Exiting loop...") - exit_early = True - elif key == keyboard.Key.left: - print("Left arrow key pressed. Exiting loop and rerecord the last episode...") - rerecord_episode = True - exit_early = True - elif key == keyboard.Key.esc: - print("Escape key pressed. Stopping data recording...") - stop_recording = True - exit_early = True - except Exception as e: - print(f"Error handling key press: {e}") + def is_headless(): + if platform.system() == "Linux": + display = os.environ.get("DISPLAY") + if display is None or display == "": + return True + return False - listener = keyboard.Listener(on_press=on_press) - listener.start() + # Only import pynput if not in a headless environment + if is_headless(): + logging.info("Headless environment detected. Keyboard input will not be available.") + else: + from pynput import keyboard + + def on_press(key): + nonlocal exit_early, rerecord_episode, stop_recording + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + exit_early = True + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + rerecord_episode = True + exit_early = True + elif key == keyboard.Key.esc: + print("Escape key pressed. Stopping data recording...") + stop_recording = True + exit_early = True + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() # Save images using threads to reach high fps (30 and more) # Using `with` to exist smoothly if an execption is raised. @@ -408,8 +421,10 @@ def record_dataset( if stop_recording or episode_index == num_episodes: logging.info("Done recording") os.system('say "Done recording"') + if not is_headless(): + listener.stop() + logging.info("Waiting for threads writing the images on disk to terminate...") - listener.stop() for _ in tqdm.tqdm( concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images" ):