Update control_utils.py
This commit is contained in:
parent
e060689dc1
commit
472853a818
|
@ -113,46 +113,73 @@ def predict_action(observation, policy, device, use_amp):
|
|||
return action
|
||||
|
||||
|
||||
# def init_keyboard_listener():
|
||||
# # Allow to exit early while recording an episode or resetting the environment,
|
||||
# # by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# # to allow your terminal to monitor keyboard events.
|
||||
# events = {}
|
||||
# events["exit_early"] = False
|
||||
# events["rerecord_episode"] = False
|
||||
# events["stop_recording"] = False
|
||||
|
||||
# if is_headless():
|
||||
# logging.warning(
|
||||
# "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
# )
|
||||
# listener = None
|
||||
# return listener, events
|
||||
|
||||
# # Only import pynput if not in a headless environment
|
||||
# from pynput import keyboard
|
||||
|
||||
# def on_press(key):
|
||||
# try:
|
||||
# if key == keyboard.Key.right:
|
||||
# print("Right arrow key pressed. Exiting loop...")
|
||||
# events["exit_early"] = True
|
||||
# elif key == keyboard.Key.left:
|
||||
# print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
# events["rerecord_episode"] = True
|
||||
# events["exit_early"] = True
|
||||
# elif key == keyboard.Key.esc:
|
||||
# print("Escape key pressed. Stopping data recording...")
|
||||
# events["stop_recording"] = True
|
||||
# events["exit_early"] = True
|
||||
# except Exception as e:
|
||||
# print(f"Error handling key press: {e}")
|
||||
|
||||
# listener = keyboard.Listener(on_press=on_press)
|
||||
# listener.start()
|
||||
|
||||
# return listener, events
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
listener = None
|
||||
return listener, events
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
from pynput import keyboard
|
||||
|
||||
from sshkeyboard import listen_keyboard
|
||||
import threading
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
if key == "right":
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
elif key == "left":
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
elif key == "q":
|
||||
print("Q key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
|
||||
listener.start()
|
||||
|
||||
return listener, events
|
||||
|
||||
return listener,events
|
||||
|
||||
def warmup_record(
|
||||
robot,
|
||||
|
@ -256,7 +283,8 @@ def control_loop(
|
|||
frame = {**observation, **action, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
if display_cameras:
|
||||
# if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
|
@ -288,17 +316,25 @@ def reset_environment(robot, events, reset_time_s, fps):
|
|||
teleoperate=True,
|
||||
)
|
||||
|
||||
# def stop_recording(robot, listener, display_cameras):
|
||||
# robot.disconnect()
|
||||
|
||||
# if not is_headless():
|
||||
# if listener is not None:
|
||||
# listener.stop()
|
||||
|
||||
# if display_cameras:
|
||||
# cv2.destroyAllWindows()
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
robot.disconnect()
|
||||
|
||||
if not is_headless():
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
if display_cameras:
|
||||
cv2.destroyAllWindows()
|
||||
from sshkeyboard import stop_listening
|
||||
if listener is not None:
|
||||
stop_listening()
|
||||
|
||||
if display_cameras:
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
_, dataset_name = repo_id.split("/")
|
||||
|
|
Loading…
Reference in New Issue