######################################################################################## # Utilities ######################################################################################## import logging import time import traceback from contextlib import nullcontext from copy import copy from functools import cache import cv2 import torch import tqdm from deepdiff import DeepDiff from termcolor import colored from lerobot.common.datasets.image_writer import safe_stop_image_writer from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.utils import get_features_from_robot from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed from lerobot.scripts.eval import get_pretrained_policy_path def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): log_items = [] if episode_index is not None: log_items.append(f"ep:{episode_index}") if frame_index is not None: log_items.append(f"frame:{frame_index}") def log_dt(shortname, dt_val_s): nonlocal log_items, fps info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)" if fps is not None: actual_fps = 1 / dt_val_s if actual_fps < fps - 1: info_str = colored(info_str, "yellow") log_items.append(info_str) # total step time displayed in milliseconds and its frequency log_dt("dt", dt_s) # TODO(aliberts): move robot-specific logs logic in robot.print_logs() if not robot.robot_type.startswith(("stretch", "Reachy")): for name in robot.leader_arms: key = f"read_leader_{name}_pos_dt_s" if key in robot.logs: log_dt("dtRlead", robot.logs[key]) for name in robot.follower_arms: key = f"write_follower_{name}_goal_pos_dt_s" if key in robot.logs: log_dt("dtWfoll", robot.logs[key]) key = f"read_follower_{name}_pos_dt_s" if key in robot.logs: log_dt("dtRfoll", robot.logs[key]) for name in robot.cameras: key = f"read_camera_{name}_dt_s" if key in robot.logs: log_dt(f"dtR{name}", robot.logs[key]) info_str = " ".join(log_items) logging.info(info_str) @cache def is_headless(): """Detects if python is running without a monitor.""" try: import pynput # noqa return False except Exception: print( "Error trying to import pynput. Switching to headless mode. " "As a result, the video stream from the cameras won't be shown, " "and you won't be able to change the control flow with keyboards. " "For more info, see traceback below.\n" ) traceback.print_exc() print() return True def has_method(_object: object, method_name: str): return hasattr(_object, method_name) and callable(getattr(_object, method_name)) def predict_action(observation, policy, device, use_amp): observation = copy(observation) with ( torch.inference_mode(), torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), ): # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension for name in observation: if "image" in name: observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].permute(2, 0, 1).contiguous() observation[name] = observation[name].unsqueeze(0) observation[name] = observation[name].to(device) # Compute the next action with the policy # based on the current observation action = policy.select_action(observation) # Remove batch dimension action = action.squeeze(0) # Move to cpu, if not already the case action = action.to("cpu") return action def init_keyboard_listener(): # Allow to exit early while recording an episode or resetting the environment, # by tapping the right arrow key '->'. This might require a sudo permission # to allow your terminal to monitor keyboard events. events = {} events["exit_early"] = False events["rerecord_episode"] = False events["stop_recording"] = False if is_headless(): logging.warning( "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." ) listener = None return listener, events # Only import pynput if not in a headless environment from pynput import keyboard def on_press(key): try: if key == keyboard.Key.right: print("Right arrow key pressed. Exiting loop...") events["exit_early"] = True elif key == keyboard.Key.left: print("Left arrow key pressed. Exiting loop and rerecord the last episode...") events["rerecord_episode"] = True events["exit_early"] = True elif key == keyboard.Key.esc: print("Escape key pressed. Stopping data recording...") events["stop_recording"] = True events["exit_early"] = True except Exception as e: print(f"Error handling key press: {e}") listener = keyboard.Listener(on_press=on_press) listener.start() return listener, events def init_policy(pretrained_policy_name_or_path, policy_overrides): """Instantiate the policy and load fps, device and use_amp from config yaml""" pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path) hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides) policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) # Check device is available device = get_safe_torch_device(hydra_cfg.device, log=True) use_amp = hydra_cfg.use_amp policy_fps = hydra_cfg.env.fps policy.eval() policy.to(device) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True set_global_seed(hydra_cfg.seed) return policy, policy_fps, device, use_amp def warmup_record( robot, events, enable_teloperation, warmup_time_s, display_cameras, fps, ): control_loop( robot=robot, control_time_s=warmup_time_s, display_cameras=display_cameras, events=events, fps=fps, teleoperate=enable_teloperation, ) def record_episode( robot, dataset, events, episode_time_s, display_cameras, policy, device, use_amp, fps, ): control_loop( robot=robot, control_time_s=episode_time_s, display_cameras=display_cameras, dataset=dataset, events=events, policy=policy, device=device, use_amp=use_amp, fps=fps, teleoperate=policy is None, ) @safe_stop_image_writer def control_loop( robot, control_time_s=None, teleoperate=False, display_cameras=False, dataset: LeRobotDataset | None = None, events=None, policy=None, device=None, use_amp=None, fps=None, ): # TODO(rcadene): Add option to record logs if not robot.is_connected: robot.connect() if events is None: events = {"exit_early": False} if control_time_s is None: control_time_s = float("inf") if teleoperate and policy is not None: raise ValueError("When `teleoperate` is True, `policy` should be None.") if dataset is not None and fps is not None and dataset.fps != fps: raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") timestamp = 0 start_episode_t = time.perf_counter() while timestamp < control_time_s: start_loop_t = time.perf_counter() if teleoperate: observation, action = robot.teleop_step(record_data=True) else: observation = robot.capture_observation() if policy is not None: pred_action = predict_action(observation, policy, device, use_amp) # Action can eventually be clipped using `max_relative_target`, # so action actually sent is saved in the dataset. action = robot.send_action(pred_action) action = {"action": action} if dataset is not None: frame = {**observation, **action} dataset.add_frame(frame) if display_cameras and not is_headless(): image_keys = [key for key in observation if "image" in key] for key in image_keys: cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) cv2.waitKey(1) if fps is not None: dt_s = time.perf_counter() - start_loop_t busy_wait(1 / fps - dt_s) dt_s = time.perf_counter() - start_loop_t log_control_info(robot, dt_s, fps=fps) timestamp = time.perf_counter() - start_episode_t if events["exit_early"]: events["exit_early"] = False break def reset_environment(robot, events, reset_time_s): # TODO(rcadene): refactor warmup_record and reset_environment # TODO(alibets): allow for teleop during reset if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() timestamp = 0 start_vencod_t = time.perf_counter() # Wait if necessary with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar: while timestamp < reset_time_s: time.sleep(1) timestamp = time.perf_counter() - start_vencod_t pbar.update(1) if events["exit_early"]: events["exit_early"] = False break def stop_recording(robot, listener, display_cameras): robot.disconnect() if not is_headless(): if listener is not None: listener.stop() if display_cameras: cv2.destroyAllWindows() def sanity_check_dataset_name(repo_id, policy): _, dataset_name = repo_id.split("/") # either repo_id doesnt start with "eval_" and there is no policy # or repo_id starts with "eval_" and there is a policy # Check if dataset_name starts with "eval_" but policy is missing if dataset_name.startswith("eval_") and policy is None: raise ValueError( f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided." ) # Check if dataset_name does not start with "eval_" but policy is provided if not dataset_name.startswith("eval_") and policy is not None: raise ValueError( f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})." ) def sanity_check_dataset_robot_compatibility( dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool ) -> None: fields = [ ("robot_type", dataset.meta.robot_type, robot.robot_type), ("fps", dataset.fps, fps), ("features", dataset.features, get_features_from_robot(robot, use_videos)), ] mismatches = [] for field, dataset_value, present_value in fields: diff = DeepDiff(dataset_value, present_value) if diff: mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") if mismatches: raise ValueError( "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) )