is_headless
This commit is contained in:
parent
42325f5990
commit
d878ec27d5
|
@ -84,6 +84,7 @@ import concurrent.futures
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
@ -94,7 +95,6 @@ import tqdm
|
||||||
from huggingface_hub import create_branch
|
from huggingface_hub import create_branch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pynput import keyboard
|
|
||||||
|
|
||||||
# from safetensors.torch import load_file, save_file
|
# from safetensors.torch import load_file, save_file
|
||||||
from lerobot.common.datasets.compute_stats import compute_stats
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
|
@ -270,25 +270,38 @@ def record_dataset(
|
||||||
rerecord_episode = False
|
rerecord_episode = False
|
||||||
stop_recording = False
|
stop_recording = False
|
||||||
|
|
||||||
def on_press(key):
|
def is_headless():
|
||||||
nonlocal exit_early, rerecord_episode, stop_recording
|
if platform.system() == "Linux":
|
||||||
try:
|
display = os.environ.get("DISPLAY")
|
||||||
if key == keyboard.Key.right:
|
if display is None or display == "":
|
||||||
print("Right arrow key pressed. Exiting loop...")
|
return True
|
||||||
exit_early = True
|
return False
|
||||||
elif key == keyboard.Key.left:
|
|
||||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
|
||||||
rerecord_episode = True
|
|
||||||
exit_early = True
|
|
||||||
elif key == keyboard.Key.esc:
|
|
||||||
print("Escape key pressed. Stopping data recording...")
|
|
||||||
stop_recording = True
|
|
||||||
exit_early = True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error handling key press: {e}")
|
|
||||||
|
|
||||||
listener = keyboard.Listener(on_press=on_press)
|
# Only import pynput if not in a headless environment
|
||||||
listener.start()
|
if is_headless():
|
||||||
|
logging.info("Headless environment detected. Keyboard input will not be available.")
|
||||||
|
else:
|
||||||
|
from pynput import keyboard
|
||||||
|
|
||||||
|
def on_press(key):
|
||||||
|
nonlocal exit_early, rerecord_episode, stop_recording
|
||||||
|
try:
|
||||||
|
if key == keyboard.Key.right:
|
||||||
|
print("Right arrow key pressed. Exiting loop...")
|
||||||
|
exit_early = True
|
||||||
|
elif key == keyboard.Key.left:
|
||||||
|
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||||
|
rerecord_episode = True
|
||||||
|
exit_early = True
|
||||||
|
elif key == keyboard.Key.esc:
|
||||||
|
print("Escape key pressed. Stopping data recording...")
|
||||||
|
stop_recording = True
|
||||||
|
exit_early = True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error handling key press: {e}")
|
||||||
|
|
||||||
|
listener = keyboard.Listener(on_press=on_press)
|
||||||
|
listener.start()
|
||||||
|
|
||||||
# Save images using threads to reach high fps (30 and more)
|
# Save images using threads to reach high fps (30 and more)
|
||||||
# Using `with` to exist smoothly if an execption is raised.
|
# Using `with` to exist smoothly if an execption is raised.
|
||||||
|
@ -408,8 +421,10 @@ def record_dataset(
|
||||||
if stop_recording or episode_index == num_episodes:
|
if stop_recording or episode_index == num_episodes:
|
||||||
logging.info("Done recording")
|
logging.info("Done recording")
|
||||||
os.system('say "Done recording"')
|
os.system('say "Done recording"')
|
||||||
|
if not is_headless():
|
||||||
|
listener.stop()
|
||||||
|
|
||||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||||
listener.stop()
|
|
||||||
for _ in tqdm.tqdm(
|
for _ in tqdm.tqdm(
|
||||||
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
|
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in New Issue