Display cameras and add reset env in run_policy (WIP)
This commit is contained in:
parent
6fc1d0dfc1
commit
3e2c43296c
lerobot
|
@ -2,6 +2,7 @@ from pathlib import Path
|
|||
from typing import Protocol
|
||||
|
||||
import cv2
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
@ -39,6 +40,16 @@ def save_depth_image(depth, path, write_shape=False):
|
|||
cv2.imwrite(str(path), depth_image)
|
||||
|
||||
|
||||
def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
|
||||
assert tensor.ndim == 3
|
||||
c, h, w = tensor.shape
|
||||
assert c < h and c < w
|
||||
color_image = einops.rearrange(tensor, "c h w -> h w c").numpy()
|
||||
if rgb_to_bgr:
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||
return color_image
|
||||
|
||||
|
||||
# Defines a camera type
|
||||
class Camera(Protocol):
|
||||
def connect(self): ...
|
||||
|
|
|
@ -88,8 +88,10 @@ import platform
|
|||
import shutil
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from huggingface_hub import create_branch
|
||||
|
@ -105,6 +107,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
|
|||
from lerobot.common.datasets.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.cameras.utils import convert_torch_image_to_cv2
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
|
@ -179,7 +182,8 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
|
|||
logging.info(info_str)
|
||||
|
||||
|
||||
def get_is_headless():
|
||||
@cache
|
||||
def is_headless():
|
||||
if platform.system() == "Linux":
|
||||
display = os.environ.get("DISPLAY")
|
||||
if display is None or display == "":
|
||||
|
@ -255,7 +259,10 @@ def record_dataset(
|
|||
else:
|
||||
episode_index = 0
|
||||
|
||||
is_headless = get_is_headless()
|
||||
if is_headless():
|
||||
logging.info(
|
||||
"Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording data, to give times
|
||||
# to the robot devices to connect and start synchronizing.
|
||||
|
@ -269,10 +276,14 @@ def record_dataset(
|
|||
is_warmup_print = True
|
||||
|
||||
now = time.perf_counter()
|
||||
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
if not is_headless:
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
|
||||
cv2.waitKey(1)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
@ -290,9 +301,7 @@ def record_dataset(
|
|||
stop_recording = False
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
if is_headless:
|
||||
logging.info("Headless environment detected. Keyboard input will not be available.")
|
||||
else:
|
||||
if not is_headless():
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
|
@ -342,6 +351,12 @@ def record_dataset(
|
|||
)
|
||||
]
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
|
||||
cv2.waitKey(1)
|
||||
|
||||
for key in not_image_keys:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
|
@ -434,7 +449,7 @@ def record_dataset(
|
|||
if is_last_episode:
|
||||
logging.info("Done recording")
|
||||
os.system('say "Done recording"')
|
||||
if not is_headless:
|
||||
if not is_headless():
|
||||
listener.stop()
|
||||
|
||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||
|
@ -543,7 +558,14 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
|||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
|
||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
|
||||
def run_policy(
|
||||
robot: Robot,
|
||||
policy: torch.nn.Module,
|
||||
hydra_cfg: DictConfig,
|
||||
warmup_time_s: float = 4,
|
||||
run_time_s: float | None = None,
|
||||
reset_time_s: float = 15,
|
||||
):
|
||||
# TODO(rcadene): Add option to record eval dataset and logs
|
||||
|
||||
# Check device is available
|
||||
|
@ -561,12 +583,76 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
|
|||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
if is_headless():
|
||||
logging.info(
|
||||
"Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording data, to give times
|
||||
# to the robot devices to connect and start synchronizing.
|
||||
timestamp = 0
|
||||
start_time = time.perf_counter()
|
||||
is_warmup_print = False
|
||||
while timestamp < warmup_time_s:
|
||||
if not is_warmup_print:
|
||||
logging.info("Warming up (no data recording)")
|
||||
os.system('say "Warmup" &')
|
||||
is_warmup_print = True
|
||||
|
||||
now = time.perf_counter()
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
|
||||
cv2.waitKey(1)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_time
|
||||
|
||||
# Allow to reset environment or exit early
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
reset_environment = False
|
||||
exit_reset = False
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
if not is_headless():
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
nonlocal reset_environment, exit_reset
|
||||
try:
|
||||
if key == keyboard.Key.right and not reset_environment:
|
||||
print("Right arrow key pressed. Suspend robot control to reset environment...")
|
||||
reset_environment = True
|
||||
elif key == keyboard.Key.right and reset_environment:
|
||||
print("Right arrow key pressed. Enable robot control and exit reset environment...")
|
||||
exit_reset = True
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
while True:
|
||||
now = time.perf_counter()
|
||||
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
|
||||
cv2.waitKey(1)
|
||||
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type)
|
||||
|
@ -597,6 +683,25 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
|
|||
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
||||
break
|
||||
|
||||
if reset_environment:
|
||||
# Start resetting env while the executor are finishing
|
||||
logging.info("Reset the environment")
|
||||
os.system('say "Reset the environment" &')
|
||||
|
||||
# Wait if necessary
|
||||
timestamp = 0
|
||||
start_time = time.perf_counter()
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
while timestamp < reset_time_s:
|
||||
time.sleep(1)
|
||||
timestamp = time.perf_counter() - start_time
|
||||
pbar.update(1)
|
||||
if exit_reset:
|
||||
exit_reset = False
|
||||
break
|
||||
|
||||
reset_environment = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
|
Loading…
Reference in New Issue