is_headless

This commit is contained in:
Remi Cadene 2024-07-14 14:52:15 +00:00
parent 42325f5990
commit d878ec27d5
1 changed files with 35 additions and 20 deletions

View File

@ -84,6 +84,7 @@ import concurrent.futures
import json import json
import logging import logging
import os import os
import platform
import shutil import shutil
import time import time
from contextlib import nullcontext from contextlib import nullcontext
@ -94,7 +95,6 @@ import tqdm
from huggingface_hub import create_branch from huggingface_hub import create_branch
from omegaconf import DictConfig from omegaconf import DictConfig
from PIL import Image from PIL import Image
from pynput import keyboard
# from safetensors.torch import load_file, save_file # from safetensors.torch import load_file, save_file
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.compute_stats import compute_stats
@ -270,25 +270,38 @@ def record_dataset(
rerecord_episode = False rerecord_episode = False
stop_recording = False stop_recording = False
def on_press(key): def is_headless():
nonlocal exit_early, rerecord_episode, stop_recording if platform.system() == "Linux":
try: display = os.environ.get("DISPLAY")
if key == keyboard.Key.right: if display is None or display == "":
print("Right arrow key pressed. Exiting loop...") return True
exit_early = True return False
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
rerecord_episode = True
exit_early = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
stop_recording = True
exit_early = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press) # Only import pynput if not in a headless environment
listener.start() if is_headless():
logging.info("Headless environment detected. Keyboard input will not be available.")
else:
from pynput import keyboard
def on_press(key):
nonlocal exit_early, rerecord_episode, stop_recording
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
exit_early = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
rerecord_episode = True
exit_early = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
stop_recording = True
exit_early = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
# Save images using threads to reach high fps (30 and more) # Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised. # Using `with` to exist smoothly if an execption is raised.
@ -408,8 +421,10 @@ def record_dataset(
if stop_recording or episode_index == num_episodes: if stop_recording or episode_index == num_episodes:
logging.info("Done recording") logging.info("Done recording")
os.system('say "Done recording"') os.system('say "Done recording"')
if not is_headless():
listener.stop()
logging.info("Waiting for threads writing the images on disk to terminate...") logging.info("Waiting for threads writing the images on disk to terminate...")
listener.stop()
for _ in tqdm.tqdm( for _ in tqdm.tqdm(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images" concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
): ):