Display cameras and add reset env in run_policy (WIP)

Remi Cadene 2024-07-17 10:37:23 +02:00
parent 6fc1d0dfc1
commit 3e2c43296c
2 changed files with 124 additions and 8 deletions

View File

@@ -2,6 +2,7 @@ from pathlib import Path
 from typing import Protocol
 
 import cv2
+import einops
 import numpy as np
@@ -39,6 +40,16 @@ def save_depth_image(depth, path, write_shape=False):
     cv2.imwrite(str(path), depth_image)
 
 
+def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
+    assert tensor.ndim == 3
+    c, h, w = tensor.shape
+    assert c < h and c < w
+    color_image = einops.rearrange(tensor, "c h w -> h w c").numpy()
+    if rgb_to_bgr:
+        color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
+    return color_image
+
+
 # Defines a camera type
 class Camera(Protocol):
     def connect(self): ...
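For reference, a minimal usage sketch of the new helper, using the import path that the control script below adds; the 3x480x640 uint8 tensor is an illustrative stand-in for a single camera frame:

import cv2
import torch

from lerobot.common.robot_devices.cameras.utils import convert_torch_image_to_cv2

# Illustrative CHW frame (channels first, matching the c < h and c < w assertion).
frame = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)

# The helper moves channels last and converts RGB -> BGR so OpenCV renders colors correctly.
bgr = convert_torch_image_to_cv2(frame)
assert bgr.shape == (480, 640, 3)

cv2.imshow("camera", bgr)
cv2.waitKey(1)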

View File

@@ -88,8 +88,10 @@ import platform
 import shutil
 import time
 from contextlib import nullcontext
+from functools import cache
 from pathlib import Path
 
+import cv2
 import torch
 import tqdm
 from huggingface_hub import create_branch
@@ -105,6 +107,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import calculate_episode_data_index
 from lerobot.common.datasets.video_utils import encode_video_frames
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.robot_devices.cameras.utils import convert_torch_image_to_cv2
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
@@ -179,7 +182,8 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
     logging.info(info_str)
 
 
-def get_is_headless():
+@cache
+def is_headless():
     if platform.system() == "Linux":
         display = os.environ.get("DISPLAY")
         if display is None or display == "":
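The @cache decorator from functools memoizes the return value, so the platform/DISPLAY inspection runs once per process and every later call reuses the stored boolean. A tiny self-contained sketch of that behaviour (the demo function name is made up):

from functools import cache

@cache
def is_headless_demo():
    # Stand-in for the real check: the body runs only on the first call.
    print("inspecting the display environment...")
    return False

is_headless_demo()  # prints the message and computes the result
is_headless_demo()  # returns the cached result without re-running the body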
@@ -255,7 +259,10 @@ def record_dataset(
     else:
         episode_index = 0
 
-    is_headless = get_is_headless()
+    if is_headless():
+        logging.info(
+            "Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
+        )
 
     # Execute a few seconds without recording data, to give times
     # to the robot devices to connect and start synchronizing.
@@ -269,10 +276,14 @@ def record_dataset(
             is_warmup_print = True
 
         now = time.perf_counter()
         observation, action = robot.teleop_step(record_data=True)
 
-        if not is_headless:
+        if not is_headless():
             image_keys = [key for key in observation if "image" in key]
+            for key in image_keys:
+                cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
+            cv2.waitKey(1)
 
         dt_s = time.perf_counter() - now
         busy_wait(1 / fps - dt_s)
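The dt_s / busy_wait pair keeps each teleoperation iteration close to the 1/fps period: the time already spent is subtracted before waiting. busy_wait itself is defined elsewhere in this script's dependencies; the spin-wait stand-in below is only an assumption about its behaviour:

import time

def busy_wait_standin(seconds: float) -> None:
    # Assumed behaviour: spin until the requested duration elapses (no-op for negative values).
    end = time.perf_counter() + seconds
    while time.perf_counter() < end:
        pass

fps = 30
now = time.perf_counter()
time.sleep(0.01)  # placeholder for the real per-iteration work (teleop step, display, ...)
dt_s = time.perf_counter() - now
busy_wait_standin(1 / fps - dt_s)  # pad the iteration up to the 1/fps period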
@@ -290,9 +301,7 @@ def record_dataset(
     stop_recording = False
 
     # Only import pynput if not in a headless environment
-    if is_headless:
-        logging.info("Headless environment detected. Keyboard input will not be available.")
-    else:
+    if not is_headless():
         from pynput import keyboard
 
         def on_press(key):
@@ -342,6 +351,12 @@ def record_dataset(
                     )
                 ]
 
+            if not is_headless():
+                image_keys = [key for key in observation if "image" in key]
+                for key in image_keys:
+                    cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
+                cv2.waitKey(1)
+
             for key in not_image_keys:
                 if key not in ep_dict:
                     ep_dict[key] = []
@@ -434,7 +449,7 @@ def record_dataset(
     if is_last_episode:
         logging.info("Done recording")
         os.system('say "Done recording"')
-        if not is_headless:
+        if not is_headless():
             listener.stop()
 
         logging.info("Waiting for threads writing the images on disk to terminate...")
@@ -543,7 +558,14 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
         log_control_info(robot, dt_s, fps=fps)
 
 
-def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
+def run_policy(
+    robot: Robot,
+    policy: torch.nn.Module,
+    hydra_cfg: DictConfig,
+    warmup_time_s: float = 4,
+    run_time_s: float | None = None,
+    reset_time_s: float = 15,
+):
     # TODO(rcadene): Add option to record eval dataset and logs
 
     # Check device is available
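With the new signature, a call site can tune the warmup, total run time, and reset window. A hypothetical invocation, assuming robot, policy, and hydra_cfg have already been built (for instance with the make_robot / make_policy factories imported above); the numeric values are illustrative:

run_policy(
    robot=robot,
    policy=policy,
    hydra_cfg=hydra_cfg,
    warmup_time_s=4,   # give the robot devices a few seconds to connect and synchronize
    run_time_s=60,     # stop after one minute; None keeps the loop running
    reset_time_s=15,   # window granted for resetting the scene when requested
)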
@@ -561,12 +583,76 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
     if not robot.is_connected:
         robot.connect()
 
+    if is_headless():
+        logging.info(
+            "Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
+        )
+
+    # Execute a few seconds without recording data, to give times
+    # to the robot devices to connect and start synchronizing.
+    timestamp = 0
+    start_time = time.perf_counter()
+    is_warmup_print = False
+    while timestamp < warmup_time_s:
+        if not is_warmup_print:
+            logging.info("Warming up (no data recording)")
+            os.system('say "Warmup" &')
+            is_warmup_print = True
+
+        now = time.perf_counter()
+        observation = robot.capture_observation()
+
+        if not is_headless():
+            image_keys = [key for key in observation if "image" in key]
+            for key in image_keys:
+                cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
+            cv2.waitKey(1)
+
+        dt_s = time.perf_counter() - now
+        busy_wait(1 / fps - dt_s)
+
+        dt_s = time.perf_counter() - now
+        log_control_info(robot, dt_s, fps=fps)
+
+        timestamp = time.perf_counter() - start_time
+
+    # Allow resetting the environment or exiting early by tapping the right arrow key '->'.
+    # This might require sudo permission to allow your terminal to monitor keyboard events.
+    reset_environment = False
+    exit_reset = False
+
+    # Only import pynput if not in a headless environment
+    if not is_headless():
+        from pynput import keyboard
+
+        def on_press(key):
+            nonlocal reset_environment, exit_reset
+            try:
+                if key == keyboard.Key.right and not reset_environment:
+                    print("Right arrow key pressed. Suspend robot control to reset environment...")
+                    reset_environment = True
+                elif key == keyboard.Key.right and reset_environment:
+                    print("Right arrow key pressed. Enable robot control and exit reset environment...")
+                    exit_reset = True
+            except Exception as e:
+                print(f"Error handling key press: {e}")
+
+        listener = keyboard.Listener(on_press=on_press)
+        listener.start()
+
     start_time = time.perf_counter()
     while True:
         now = time.perf_counter()
         observation = robot.capture_observation()
 
+        if not is_headless():
+            image_keys = [key for key in observation if "image" in key]
+            for key in image_keys:
+                cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
+            cv2.waitKey(1)
+
         with (
             torch.inference_mode(),
             torch.autocast(device_type=device.type)
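The reset protocol added above relies on pressing the right arrow twice: the first press sets reset_environment so the control loop pauses, the second sets exit_reset so the wait ends early. A stripped-down sketch of the same listener pattern (flag names are illustrative, and pynput needs a non-headless session):

from pynput import keyboard

reset_requested = False
exit_reset = False

def on_press(key):
    global reset_requested, exit_reset
    if key == keyboard.Key.right and not reset_requested:
        print("Reset requested: robot control will be suspended.")
        reset_requested = True
    elif key == keyboard.Key.right and reset_requested:
        print("Reset finished: robot control resumes.")
        exit_reset = True

# The listener runs in a background thread, so the main loop keeps polling the flags.
listener = keyboard.Listener(on_press=on_press)
listener.start()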
@@ -597,6 +683,25 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
         if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
             break
 
+        if reset_environment:
+            # Start resetting the env while the executor is finishing
+            logging.info("Reset the environment")
+            os.system('say "Reset the environment" &')
+
+            # Wait if necessary
+            timestamp = 0
+            start_time = time.perf_counter()
+            with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
+                while timestamp < reset_time_s:
+                    time.sleep(1)
+                    timestamp = time.perf_counter() - start_time
+                    pbar.update(1)
+                    if exit_reset:
+                        exit_reset = False
+                        break
+
+            reset_environment = False
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()