Update control_utils.py

This commit is contained in:
DUDULRX 2025-03-05 20:11:05 +08:00
parent e060689dc1
commit 472853a818
1 changed files with 64 additions and 28 deletions

View File

@ -113,46 +113,73 @@ def predict_action(observation, policy, device, use_amp):
return action
# def init_keyboard_listener():
# # Allow to exit early while recording an episode or resetting the environment,
# # by tapping the right arrow key '->'. This might require a sudo permission
# # to allow your terminal to monitor keyboard events.
# events = {}
# events["exit_early"] = False
# events["rerecord_episode"] = False
# events["stop_recording"] = False
# if is_headless():
# logging.warning(
# "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
# )
# listener = None
# return listener, events
# # Only import pynput if not in a headless environment
# from pynput import keyboard
# def on_press(key):
# try:
# if key == keyboard.Key.right:
# print("Right arrow key pressed. Exiting loop...")
# events["exit_early"] = True
# elif key == keyboard.Key.left:
# print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
# events["rerecord_episode"] = True
# events["exit_early"] = True
# elif key == keyboard.Key.esc:
# print("Escape key pressed. Stopping data recording...")
# events["stop_recording"] = True
# events["exit_early"] = True
# except Exception as e:
# print(f"Error handling key press: {e}")
# listener = keyboard.Listener(on_press=on_press)
# listener.start()
# return listener, events
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
from sshkeyboard import listen_keyboard
import threading
def on_press(key):
try:
if key == keyboard.Key.right:
if key == "right":
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
elif key == "left":
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
elif key == "q":
print("Q key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
listener.start()
return listener, events
return listener,events
def warmup_record(
robot,
@ -256,7 +283,8 @@ def control_loop(
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
if display_cameras and not is_headless():
if display_cameras:
# if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@ -288,17 +316,25 @@ def reset_environment(robot, events, reset_time_s, fps):
teleoperate=True,
)
# def stop_recording(robot, listener, display_cameras):
# robot.disconnect()
# if not is_headless():
# if listener is not None:
# listener.stop()
# if display_cameras:
# cv2.destroyAllWindows()
def stop_recording(robot, listener, display_cameras):
robot.disconnect()
if not is_headless():
if listener is not None:
listener.stop()
if display_cameras:
cv2.destroyAllWindows()
from sshkeyboard import stop_listening
if listener is not None:
stop_listening()
if display_cameras:
cv2.destroyAllWindows()
def sanity_check_dataset_name(repo_id, policy_cfg):
_, dataset_name = repo_id.split("/")