Update control_utils.py

This commit is contained in:
DUDULRX 2025-03-05 20:11:05 +08:00
parent e060689dc1
commit 472853a818
1 changed files with 64 additions and 28 deletions

View File

@ -113,46 +113,73 @@ def predict_action(observation, policy, device, use_amp):
return action return action
# def init_keyboard_listener():
# # Allow to exit early while recording an episode or resetting the environment,
# # by tapping the right arrow key '->'. This might require a sudo permission
# # to allow your terminal to monitor keyboard events.
# events = {}
# events["exit_early"] = False
# events["rerecord_episode"] = False
# events["stop_recording"] = False
# if is_headless():
# logging.warning(
# "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
# )
# listener = None
# return listener, events
# # Only import pynput if not in a headless environment
# from pynput import keyboard
# def on_press(key):
# try:
# if key == keyboard.Key.right:
# print("Right arrow key pressed. Exiting loop...")
# events["exit_early"] = True
# elif key == keyboard.Key.left:
# print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
# events["rerecord_episode"] = True
# events["exit_early"] = True
# elif key == keyboard.Key.esc:
# print("Escape key pressed. Stopping data recording...")
# events["stop_recording"] = True
# events["exit_early"] = True
# except Exception as e:
# print(f"Error handling key press: {e}")
# listener = keyboard.Listener(on_press=on_press)
# listener.start()
# return listener, events
def init_keyboard_listener(): def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {} events = {}
events["exit_early"] = False events["exit_early"] = False
events["rerecord_episode"] = False events["rerecord_episode"] = False
events["stop_recording"] = False events["stop_recording"] = False
from sshkeyboard import listen_keyboard
if is_headless(): import threading
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key): def on_press(key):
try: try:
if key == keyboard.Key.right: if key == "right":
print("Right arrow key pressed. Exiting loop...") print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True events["exit_early"] = True
elif key == keyboard.Key.left: elif key == "left":
print("Left arrow key pressed. Exiting loop and rerecord the last episode...") print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True events["rerecord_episode"] = True
events["exit_early"] = True events["exit_early"] = True
elif key == keyboard.Key.esc: elif key == "q":
print("Escape key pressed. Stopping data recording...") print("Q key pressed. Stopping data recording...")
events["stop_recording"] = True events["stop_recording"] = True
events["exit_early"] = True events["exit_early"] = True
except Exception as e: except Exception as e:
print(f"Error handling key press: {e}") print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press) listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
listener.start() listener.start()
return listener, events return listener,events
def warmup_record( def warmup_record(
robot, robot,
@ -256,7 +283,8 @@ def control_loop(
frame = {**observation, **action, "task": single_task} frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame) dataset.add_frame(frame)
if display_cameras and not is_headless(): if display_cameras:
# if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key] image_keys = [key for key in observation if "image" in key]
for key in image_keys: for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@ -288,18 +316,26 @@ def reset_environment(robot, events, reset_time_s, fps):
teleoperate=True, teleoperate=True,
) )
# def stop_recording(robot, listener, display_cameras):
# robot.disconnect()
# if not is_headless():
# if listener is not None:
# listener.stop()
# if display_cameras:
# cv2.destroyAllWindows()
def stop_recording(robot, listener, display_cameras): def stop_recording(robot, listener, display_cameras):
robot.disconnect() robot.disconnect()
if not is_headless(): from sshkeyboard import stop_listening
if listener is not None: if listener is not None:
listener.stop() stop_listening()
if display_cameras: if display_cameras:
cv2.destroyAllWindows() cv2.destroyAllWindows()
def sanity_check_dataset_name(repo_id, policy_cfg): def sanity_check_dataset_name(repo_id, policy_cfg):
_, dataset_name = repo_id.split("/") _, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy # either repo_id doesnt start with "eval_" and there is no policy