Display cameras and add reset env in run_policy (WIP)

This commit is contained in:
Remi Cadene 2024-07-17 10:37:23 +02:00
parent 6fc1d0dfc1
commit 3e2c43296c
2 changed files with 124 additions and 8 deletions

View File

@ -2,6 +2,7 @@ from pathlib import Path
from typing import Protocol
import cv2
import einops
import numpy as np
@ -39,6 +40,16 @@ def save_depth_image(depth, path, write_shape=False):
cv2.imwrite(str(path), depth_image)
def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
assert tensor.ndim == 3
c, h, w = tensor.shape
assert c < h and c < w
color_image = einops.rearrange(tensor, "c h w -> h w c").numpy()
if rgb_to_bgr:
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
return color_image
# Defines a camera type
class Camera(Protocol):
def connect(self): ...

View File

@ -88,8 +88,10 @@ import platform
import shutil
import time
from contextlib import nullcontext
from functools import cache
from pathlib import Path
import cv2
import torch
import tqdm
from huggingface_hub import create_branch
@ -105,6 +107,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import calculate_episode_data_index
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.cameras.utils import convert_torch_image_to_cv2
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
@ -179,7 +182,8 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
logging.info(info_str)
def get_is_headless():
@cache
def is_headless():
if platform.system() == "Linux":
display = os.environ.get("DISPLAY")
if display is None or display == "":
@ -255,7 +259,10 @@ def record_dataset(
else:
episode_index = 0
is_headless = get_is_headless()
if is_headless():
logging.info(
"Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
)
# Execute a few seconds without recording data, to give times
# to the robot devices to connect and start synchronizing.
@ -269,10 +276,14 @@ def record_dataset(
is_warmup_print = True
now = time.perf_counter()
observation, action = robot.teleop_step(record_data=True)
if not is_headless:
if not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
cv2.waitKey(1)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
@ -290,9 +301,7 @@ def record_dataset(
stop_recording = False
# Only import pynput if not in a headless environment
if is_headless:
logging.info("Headless environment detected. Keyboard input will not be available.")
else:
if not is_headless():
from pynput import keyboard
def on_press(key):
@ -342,6 +351,12 @@ def record_dataset(
)
]
if not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
cv2.waitKey(1)
for key in not_image_keys:
if key not in ep_dict:
ep_dict[key] = []
@ -434,7 +449,7 @@ def record_dataset(
if is_last_episode:
logging.info("Done recording")
os.system('say "Done recording"')
if not is_headless:
if not is_headless():
listener.stop()
logging.info("Waiting for threads writing the images on disk to terminate...")
@ -543,7 +558,14 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
log_control_info(robot, dt_s, fps=fps)
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
def run_policy(
robot: Robot,
policy: torch.nn.Module,
hydra_cfg: DictConfig,
warmup_time_s: float = 4,
run_time_s: float | None = None,
reset_time_s: float = 15,
):
# TODO(rcadene): Add option to record eval dataset and logs
# Check device is available
@ -561,12 +583,76 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
if not robot.is_connected:
robot.connect()
if is_headless():
logging.info(
"Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
)
# Execute a few seconds without recording data, to give times
# to the robot devices to connect and start synchronizing.
timestamp = 0
start_time = time.perf_counter()
is_warmup_print = False
while timestamp < warmup_time_s:
if not is_warmup_print:
logging.info("Warming up (no data recording)")
os.system('say "Warmup" &')
is_warmup_print = True
now = time.perf_counter()
observation = robot.capture_observation()
if not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
cv2.waitKey(1)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time
# Allow to reset environment or exit early
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
reset_environment = False
exit_reset = False
# Only import pynput if not in a headless environment
if not is_headless():
from pynput import keyboard
def on_press(key):
nonlocal reset_environment, exit_reset
try:
if key == keyboard.Key.right and not reset_environment:
print("Right arrow key pressed. Suspend robot control to reset environment...")
reset_environment = True
elif key == keyboard.Key.right and reset_environment:
print("Right arrow key pressed. Enable robot control and exit reset environment...")
exit_reset = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
start_time = time.perf_counter()
while True:
now = time.perf_counter()
observation = robot.capture_observation()
if not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
cv2.waitKey(1)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
@ -597,6 +683,25 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
break
if reset_environment:
# Start resetting env while the executor are finishing
logging.info("Reset the environment")
os.system('say "Reset the environment" &')
# Wait if necessary
timestamp = 0
start_time = time.perf_counter()
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
while timestamp < reset_time_s:
time.sleep(1)
timestamp = time.perf_counter() - start_time
pbar.update(1)
if exit_reset:
exit_reset = False
break
reset_environment = False
if __name__ == "__main__":
parser = argparse.ArgumentParser()