Update control_utils.py
This commit is contained in:
parent
e060689dc1
commit
472853a818
|
@ -113,46 +113,73 @@ def predict_action(observation, policy, device, use_amp):
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
# def init_keyboard_listener():
|
||||||
|
# # Allow to exit early while recording an episode or resetting the environment,
|
||||||
|
# # by tapping the right arrow key '->'. This might require a sudo permission
|
||||||
|
# # to allow your terminal to monitor keyboard events.
|
||||||
|
# events = {}
|
||||||
|
# events["exit_early"] = False
|
||||||
|
# events["rerecord_episode"] = False
|
||||||
|
# events["stop_recording"] = False
|
||||||
|
|
||||||
|
# if is_headless():
|
||||||
|
# logging.warning(
|
||||||
|
# "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||||
|
# )
|
||||||
|
# listener = None
|
||||||
|
# return listener, events
|
||||||
|
|
||||||
|
# # Only import pynput if not in a headless environment
|
||||||
|
# from pynput import keyboard
|
||||||
|
|
||||||
|
# def on_press(key):
|
||||||
|
# try:
|
||||||
|
# if key == keyboard.Key.right:
|
||||||
|
# print("Right arrow key pressed. Exiting loop...")
|
||||||
|
# events["exit_early"] = True
|
||||||
|
# elif key == keyboard.Key.left:
|
||||||
|
# print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||||
|
# events["rerecord_episode"] = True
|
||||||
|
# events["exit_early"] = True
|
||||||
|
# elif key == keyboard.Key.esc:
|
||||||
|
# print("Escape key pressed. Stopping data recording...")
|
||||||
|
# events["stop_recording"] = True
|
||||||
|
# events["exit_early"] = True
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"Error handling key press: {e}")
|
||||||
|
|
||||||
|
# listener = keyboard.Listener(on_press=on_press)
|
||||||
|
# listener.start()
|
||||||
|
|
||||||
|
# return listener, events
|
||||||
|
|
||||||
def init_keyboard_listener():
|
def init_keyboard_listener():
|
||||||
# Allow to exit early while recording an episode or resetting the environment,
|
|
||||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
|
||||||
# to allow your terminal to monitor keyboard events.
|
|
||||||
events = {}
|
events = {}
|
||||||
events["exit_early"] = False
|
events["exit_early"] = False
|
||||||
events["rerecord_episode"] = False
|
events["rerecord_episode"] = False
|
||||||
events["stop_recording"] = False
|
events["stop_recording"] = False
|
||||||
|
from sshkeyboard import listen_keyboard
|
||||||
if is_headless():
|
import threading
|
||||||
logging.warning(
|
|
||||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
|
||||||
)
|
|
||||||
listener = None
|
|
||||||
return listener, events
|
|
||||||
|
|
||||||
# Only import pynput if not in a headless environment
|
|
||||||
from pynput import keyboard
|
|
||||||
|
|
||||||
def on_press(key):
|
def on_press(key):
|
||||||
try:
|
try:
|
||||||
if key == keyboard.Key.right:
|
if key == "right":
|
||||||
print("Right arrow key pressed. Exiting loop...")
|
print("Right arrow key pressed. Exiting loop...")
|
||||||
events["exit_early"] = True
|
events["exit_early"] = True
|
||||||
elif key == keyboard.Key.left:
|
elif key == "left":
|
||||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||||
events["rerecord_episode"] = True
|
events["rerecord_episode"] = True
|
||||||
events["exit_early"] = True
|
events["exit_early"] = True
|
||||||
elif key == keyboard.Key.esc:
|
elif key == "q":
|
||||||
print("Escape key pressed. Stopping data recording...")
|
print("Q key pressed. Stopping data recording...")
|
||||||
events["stop_recording"] = True
|
events["stop_recording"] = True
|
||||||
events["exit_early"] = True
|
events["exit_early"] = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error handling key press: {e}")
|
print(f"Error handling key press: {e}")
|
||||||
|
|
||||||
listener = keyboard.Listener(on_press=on_press)
|
listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
|
||||||
listener.start()
|
listener.start()
|
||||||
|
|
||||||
return listener, events
|
return listener,events
|
||||||
|
|
||||||
|
|
||||||
def warmup_record(
|
def warmup_record(
|
||||||
robot,
|
robot,
|
||||||
|
@ -256,7 +283,8 @@ def control_loop(
|
||||||
frame = {**observation, **action, "task": single_task}
|
frame = {**observation, **action, "task": single_task}
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
if display_cameras and not is_headless():
|
if display_cameras:
|
||||||
|
# if display_cameras and not is_headless():
|
||||||
image_keys = [key for key in observation if "image" in key]
|
image_keys = [key for key in observation if "image" in key]
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||||
|
@ -288,18 +316,26 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||||
teleoperate=True,
|
teleoperate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# def stop_recording(robot, listener, display_cameras):
|
||||||
|
# robot.disconnect()
|
||||||
|
|
||||||
|
# if not is_headless():
|
||||||
|
# if listener is not None:
|
||||||
|
# listener.stop()
|
||||||
|
|
||||||
|
# if display_cameras:
|
||||||
|
# cv2.destroyAllWindows()
|
||||||
|
|
||||||
def stop_recording(robot, listener, display_cameras):
|
def stop_recording(robot, listener, display_cameras):
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
|
|
||||||
if not is_headless():
|
from sshkeyboard import stop_listening
|
||||||
if listener is not None:
|
if listener is not None:
|
||||||
listener.stop()
|
stop_listening()
|
||||||
|
|
||||||
if display_cameras:
|
if display_cameras:
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||||
_, dataset_name = repo_id.split("/")
|
_, dataset_name = repo_id.split("/")
|
||||||
# either repo_id doesnt start with "eval_" and there is no policy
|
# either repo_id doesnt start with "eval_" and there is no policy
|
||||||
|
|
Loading…
Reference in New Issue