Display cameras and add reset env in run_policy (WIP)
This commit is contained in:
parent
6fc1d0dfc1
commit
3e2c43296c
|
@ -2,6 +2,7 @@ from pathlib import Path
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,6 +40,16 @@ def save_depth_image(depth, path, write_shape=False):
|
||||||
cv2.imwrite(str(path), depth_image)
|
cv2.imwrite(str(path), depth_image)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
    """Convert a channel-first torch image into a channel-last numpy array.

    Args:
        tensor: 3-dimensional torch tensor laid out as (channels, height, width).
            The channel dimension must be smaller than both spatial dimensions.
        rgb_to_bgr: When True (default), the channel order is converted with
            ``cv2.COLOR_RGB2BGR`` before returning.

    Returns:
        A numpy array laid out as (height, width, channels).

    Raises:
        AssertionError: If ``tensor`` is not 3-dimensional, or if its first
            dimension is not strictly smaller than the other two.
    """
    assert tensor.ndim == 3, f"expected a 3-dimensional tensor, got ndim={tensor.ndim}"
    c, h, w = tensor.shape
    # Guard against a tensor that is already channel-last (h, w, c).
    assert c < h and c < w, f"expected a channel-first (c, h, w) tensor, got shape {tuple(tensor.shape)}"

    # `.numpy()` raises on tensors that require grad or are stored on a
    # non-CPU device, so detach and move to CPU first. Both calls are cheap
    # no-ops for a plain CPU tensor.
    color_image = tensor.detach().cpu().permute(1, 2, 0).numpy()

    if rgb_to_bgr:
        color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
    return color_image
|
||||||
|
|
||||||
|
|
||||||
# Defines a camera type
|
# Defines a camera type
|
||||||
class Camera(Protocol):
|
class Camera(Protocol):
|
||||||
def connect(self): ...
|
def connect(self): ...
|
||||||
|
|
|
@ -88,8 +88,10 @@ import platform
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
from functools import cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from huggingface_hub import create_branch
|
from huggingface_hub import create_branch
|
||||||
|
@ -105,6 +107,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
|
||||||
from lerobot.common.datasets.utils import calculate_episode_data_index
|
from lerobot.common.datasets.utils import calculate_episode_data_index
|
||||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.robot_devices.cameras.utils import convert_torch_image_to_cv2
|
||||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||||
|
@ -179,7 +182,8 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||||
logging.info(info_str)
|
logging.info(info_str)
|
||||||
|
|
||||||
|
|
||||||
def get_is_headless():
|
@cache
|
||||||
|
def is_headless():
|
||||||
if platform.system() == "Linux":
|
if platform.system() == "Linux":
|
||||||
display = os.environ.get("DISPLAY")
|
display = os.environ.get("DISPLAY")
|
||||||
if display is None or display == "":
|
if display is None or display == "":
|
||||||
|
@ -255,7 +259,10 @@ def record_dataset(
|
||||||
else:
|
else:
|
||||||
episode_index = 0
|
episode_index = 0
|
||||||
|
|
||||||
is_headless = get_is_headless()
|
if is_headless():
|
||||||
|
logging.info(
|
||||||
|
"Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
|
||||||
|
)
|
||||||
|
|
||||||
# Execute a few seconds without recording data, to give times
|
# Execute a few seconds without recording data, to give times
|
||||||
# to the robot devices to connect and start synchronizing.
|
# to the robot devices to connect and start synchronizing.
|
||||||
|
@ -269,10 +276,14 @@ def record_dataset(
|
||||||
is_warmup_print = True
|
is_warmup_print = True
|
||||||
|
|
||||||
now = time.perf_counter()
|
now = time.perf_counter()
|
||||||
|
|
||||||
observation, action = robot.teleop_step(record_data=True)
|
observation, action = robot.teleop_step(record_data=True)
|
||||||
|
|
||||||
if not is_headless:
|
if not is_headless():
|
||||||
image_keys = [key for key in observation if "image" in key]
|
image_keys = [key for key in observation if "image" in key]
|
||||||
|
for key in image_keys:
|
||||||
|
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
|
||||||
|
cv2.waitKey(1)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
@ -290,9 +301,7 @@ def record_dataset(
|
||||||
stop_recording = False
|
stop_recording = False
|
||||||
|
|
||||||
# Only import pynput if not in a headless environment
|
# Only import pynput if not in a headless environment
|
||||||
if is_headless:
|
if not is_headless():
|
||||||
logging.info("Headless environment detected. Keyboard input will not be available.")
|
|
||||||
else:
|
|
||||||
from pynput import keyboard
|
from pynput import keyboard
|
||||||
|
|
||||||
def on_press(key):
|
def on_press(key):
|
||||||
|
@ -342,6 +351,12 @@ def record_dataset(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if not is_headless():
|
||||||
|
image_keys = [key for key in observation if "image" in key]
|
||||||
|
for key in image_keys:
|
||||||
|
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
|
||||||
|
cv2.waitKey(1)
|
||||||
|
|
||||||
for key in not_image_keys:
|
for key in not_image_keys:
|
||||||
if key not in ep_dict:
|
if key not in ep_dict:
|
||||||
ep_dict[key] = []
|
ep_dict[key] = []
|
||||||
|
@ -434,7 +449,7 @@ def record_dataset(
|
||||||
if is_last_episode:
|
if is_last_episode:
|
||||||
logging.info("Done recording")
|
logging.info("Done recording")
|
||||||
os.system('say "Done recording"')
|
os.system('say "Done recording"')
|
||||||
if not is_headless:
|
if not is_headless():
|
||||||
listener.stop()
|
listener.stop()
|
||||||
|
|
||||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||||
|
@ -543,7 +558,14 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
||||||
log_control_info(robot, dt_s, fps=fps)
|
log_control_info(robot, dt_s, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
|
def run_policy(
|
||||||
|
robot: Robot,
|
||||||
|
policy: torch.nn.Module,
|
||||||
|
hydra_cfg: DictConfig,
|
||||||
|
warmup_time_s: float = 4,
|
||||||
|
run_time_s: float | None = None,
|
||||||
|
reset_time_s: float = 15,
|
||||||
|
):
|
||||||
# TODO(rcadene): Add option to record eval dataset and logs
|
# TODO(rcadene): Add option to record eval dataset and logs
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
|
@ -561,12 +583,76 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|
||||||
|
if is_headless():
|
||||||
|
logging.info(
|
||||||
|
"Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute a few seconds without recording data, to give times
|
||||||
|
# to the robot devices to connect and start synchronizing.
|
||||||
|
timestamp = 0
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
is_warmup_print = False
|
||||||
|
while timestamp < warmup_time_s:
|
||||||
|
if not is_warmup_print:
|
||||||
|
logging.info("Warming up (no data recording)")
|
||||||
|
os.system('say "Warmup" &')
|
||||||
|
is_warmup_print = True
|
||||||
|
|
||||||
|
now = time.perf_counter()
|
||||||
|
observation = robot.capture_observation()
|
||||||
|
|
||||||
|
if not is_headless():
|
||||||
|
image_keys = [key for key in observation if "image" in key]
|
||||||
|
for key in image_keys:
|
||||||
|
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
|
||||||
|
cv2.waitKey(1)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - now
|
||||||
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - now
|
||||||
|
log_control_info(robot, dt_s, fps=fps)
|
||||||
|
|
||||||
|
timestamp = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
# Allow to reset environment or exit early
|
||||||
|
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||||
|
# to allow your terminal to monitor keyboard events.
|
||||||
|
reset_environment = False
|
||||||
|
exit_reset = False
|
||||||
|
|
||||||
|
# Only import pynput if not in a headless environment
|
||||||
|
if not is_headless():
|
||||||
|
from pynput import keyboard
|
||||||
|
|
||||||
|
def on_press(key):
|
||||||
|
nonlocal reset_environment, exit_reset
|
||||||
|
try:
|
||||||
|
if key == keyboard.Key.right and not reset_environment:
|
||||||
|
print("Right arrow key pressed. Suspend robot control to reset environment...")
|
||||||
|
reset_environment = True
|
||||||
|
elif key == keyboard.Key.right and reset_environment:
|
||||||
|
print("Right arrow key pressed. Enable robot control and exit reset environment...")
|
||||||
|
exit_reset = True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error handling key press: {e}")
|
||||||
|
|
||||||
|
listener = keyboard.Listener(on_press=on_press)
|
||||||
|
listener.start()
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
while True:
|
while True:
|
||||||
now = time.perf_counter()
|
now = time.perf_counter()
|
||||||
|
|
||||||
observation = robot.capture_observation()
|
observation = robot.capture_observation()
|
||||||
|
|
||||||
|
if not is_headless():
|
||||||
|
image_keys = [key for key in observation if "image" in key]
|
||||||
|
for key in image_keys:
|
||||||
|
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
|
||||||
|
cv2.waitKey(1)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
torch.inference_mode(),
|
torch.inference_mode(),
|
||||||
torch.autocast(device_type=device.type)
|
torch.autocast(device_type=device.type)
|
||||||
|
@ -597,6 +683,25 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
|
||||||
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if reset_environment:
|
||||||
|
# Start resetting env while the executor are finishing
|
||||||
|
logging.info("Reset the environment")
|
||||||
|
os.system('say "Reset the environment" &')
|
||||||
|
|
||||||
|
# Wait if necessary
|
||||||
|
timestamp = 0
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||||
|
while timestamp < reset_time_s:
|
||||||
|
time.sleep(1)
|
||||||
|
timestamp = time.perf_counter() - start_time
|
||||||
|
pbar.update(1)
|
||||||
|
if exit_reset:
|
||||||
|
exit_reset = False
|
||||||
|
break
|
||||||
|
|
||||||
|
reset_environment = False
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
Loading…
Reference in New Issue