is_headless

This commit is contained in:
Remi Cadene 2024-07-14 14:52:15 +00:00
parent 42325f5990
commit d878ec27d5
1 changed files with 35 additions and 20 deletions

View File

@ -84,6 +84,7 @@ import concurrent.futures
import json import json
import logging import logging
import os import os
import platform
import shutil import shutil
import time import time
from contextlib import nullcontext from contextlib import nullcontext
@ -94,7 +95,6 @@ import tqdm
from huggingface_hub import create_branch from huggingface_hub import create_branch
from omegaconf import DictConfig from omegaconf import DictConfig
from PIL import Image from PIL import Image
from pynput import keyboard
# from safetensors.torch import load_file, save_file # from safetensors.torch import load_file, save_file
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.compute_stats import compute_stats
@ -270,6 +270,19 @@ def record_dataset(
rerecord_episode = False rerecord_episode = False
stop_recording = False stop_recording = False
def is_headless():
if platform.system() == "Linux":
display = os.environ.get("DISPLAY")
if display is None or display == "":
return True
return False
# Only import pynput if not in a headless environment
if is_headless():
logging.info("Headless environment detected. Keyboard input will not be available.")
else:
from pynput import keyboard
def on_press(key): def on_press(key):
nonlocal exit_early, rerecord_episode, stop_recording nonlocal exit_early, rerecord_episode, stop_recording
try: try:
@ -408,8 +421,10 @@ def record_dataset(
if stop_recording or episode_index == num_episodes: if stop_recording or episode_index == num_episodes:
logging.info("Done recording") logging.info("Done recording")
os.system('say "Done recording"') os.system('say "Done recording"')
logging.info("Waiting for threads writing the images on disk to terminate...") if not is_headless():
listener.stop() listener.stop()
logging.info("Waiting for threads writing the images on disk to terminate...")
for _ in tqdm.tqdm( for _ in tqdm.tqdm(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images" concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
): ):