# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

########################################################################################
# Utilities
########################################################################################

import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
from functools import cache

import rerun as rr
import torch
from deepdiff import DeepDiff
from termcolor import colored

from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method

def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
    log_items = []
    if episode_index is not None:
        log_items.append(f"ep:{episode_index}")
    if frame_index is not None:
        log_items.append(f"frame:{frame_index}")

    def log_dt(shortname, dt_val_s):
        nonlocal log_items, fps
        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
        if fps is not None:
            actual_fps = 1 / dt_val_s
            if actual_fps < fps - 1:
                info_str = colored(info_str, "yellow")
        log_items.append(info_str)

    # total step time displayed in milliseconds and its frequency
    log_dt("dt", dt_s)

    # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
    if not robot.robot_type.startswith("stretch"):
        for name in robot.leader_arms:
            key = f"read_leader_{name}_pos_dt_s"
            if key in robot.logs:
                log_dt("dtRlead", robot.logs[key])

        for name in robot.follower_arms:
            key = f"write_follower_{name}_goal_pos_dt_s"
            if key in robot.logs:
                log_dt("dtWfoll", robot.logs[key])

            key = f"read_follower_{name}_pos_dt_s"
            if key in robot.logs:
                log_dt("dtRfoll", robot.logs[key])

        for name in robot.cameras:
            key = f"read_camera_{name}_dt_s"
            if key in robot.logs:
                log_dt(f"dtR{name}", robot.logs[key])

        for name in robot.microphones:
            key = f"read_microphone_{name}_dt_s"
            if key in robot.logs:
                log_dt(f"dtR{name}", robot.logs[key])

    info_str = " ".join(log_items)
    logging.info(info_str)

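
# Example (hypothetical values): with fps=30, a call such as
#   log_control_info(robot, dt_s=0.033, episode_index=2, frame_index=120, fps=30)
# emits a single log line along the lines of
#   "ep:2 frame:120 dt:33.00 (30.3hz) dtRlead:... dtWfoll:..."
# where any timing entry whose measured rate drops more than 1 Hz below fps is colored yellow.
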
@cache
def is_headless():
    """Detects if python is running without a monitor."""
    try:
        import pynput  # noqa

        return False
    except Exception:
        print(
            "Error trying to import pynput. Switching to headless mode. "
            "As a result, the video stream from the cameras won't be shown, "
            "and you won't be able to change the control flow with keyboards. "
            "For more info, see traceback below.\n"
        )
        traceback.print_exc()
        print()
        return True

def predict_action(observation, policy, device, use_amp):
    observation = copy(observation)
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        for name in observation:
            # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
            # Convert to pytorch format: channel first and float32 in [-1,1] (always the case here) with batch dimension
            if "audio" in name:
                observation[name] = observation[name].permute(1, 0).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)

        # Compute the next action with the policy
        # based on the current observation
        action = policy.select_action(observation)

        # Remove batch dimension
        action = action.squeeze(0)

        # Move to cpu, if not already the case
        action = action.to("cpu")

    return action

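
# Usage sketch (hypothetical observation keys; the exact layout depends on the robot config):
#
#   observation = robot.capture_observation()  # e.g. {"observation.images.top": HxWxC uint8 tensor, ...}
#   device = get_safe_torch_device(policy.config.device)
#   action = predict_action(observation, policy, device, policy.config.use_amp)
#   robot.send_action(action)
#
# Images are converted to channel-first float32 in [0, 1], audio chunks are permuted to channel-first,
# every entry gets a batch dimension and is moved to `device`; the predicted action is returned on CPU
# with the batch dimension removed.
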
def init_keyboard_listener():
    # Allow exiting early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    events = {}
    events["exit_early"] = False
    events["rerecord_episode"] = False
    events["stop_recording"] = False

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        listener = None
        return listener, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Exiting loop and rerecording the last episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.esc:
                print("Escape key pressed. Stopping data recording...")
                events["stop_recording"] = True
                events["exit_early"] = True
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events

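
# Usage sketch (hypothetical; the listener runs in a background thread while a control loop
# polls the shared `events` dict):
#
#   listener, events = init_keyboard_listener()
#   record_episode(..., events=events, ...)
#   if events["rerecord_episode"]:
#       ...  # discard the episode buffer and record it again
#   if not is_headless() and listener is not None:
#       listener.stop()
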
def warmup_record(
    robot,
    events,
    enable_teleoperation,
    warmup_time_s,
    display_data,
    fps,
):
    control_loop(
        robot=robot,
        control_time_s=warmup_time_s,
        display_data=display_data,
        events=events,
        fps=fps,
        teleoperate=enable_teleoperation,
    )

def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_data,
    policy,
    fps,
    single_task,
):
    control_loop(
        robot=robot,
        control_time_s=episode_time_s,
        display_data=display_data,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        teleoperate=policy is None,
        single_task=single_task,
    )

@safe_stop_image_writer
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_data=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy: PreTrainedPolicy = None,
    fps: int | None = None,
    single_task: str | None = None,
):
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
        robot.connect()

    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if teleoperate and policy is not None:
        raise ValueError("When `teleoperate` is True, `policy` should be None.")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()

    # For now, LeKiwi only supports frame audio recording (which may lead to audio chunk loss,
    # extended post-processing and increased memory usage).
    if dataset is not None and not robot.robot_type.startswith("lekiwi"):
        for microphone_key, microphone in robot.microphones.items():
            # Start recording both in file writing and data reading mode
            dataset.add_microphone_recording(microphone, microphone_key)
    else:
        for _, microphone in robot.microphones.items():
            # Start recording only in data reading mode
            microphone.start_recording()

    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step(record_data=True)
        else:
            observation = robot.capture_observation()

            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so the action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name), but anyways it will change soon
        if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
            for k, v in action.items():
                for i, vv in enumerate(v):
                    rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))

            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
                rr.log(key, rr.Image(observation[key].numpy()), static=True)

        if fps is not None:
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break

    for _, microphone in robot.microphones.items():
        microphone.stop_recording()

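
# Usage sketch (hypothetical values): running a 10 s teleoperation loop at 30 fps without recording,
# assuming `robot` is already configured:
#
#   control_loop(robot, control_time_s=10, teleoperate=True, fps=30)
#
# With a dataset, each iteration stores one frame ({**observation, **action, "task": single_task}),
# and busy_wait pads every iteration so the loop targets a period of 1 / fps seconds.
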
def reset_environment(robot, events, reset_time_s, fps):
    # TODO(rcadene): refactor warmup_record and reset_environment
    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()

    control_loop(
        robot=robot,
        control_time_s=reset_time_s,
        events=events,
        fps=fps,
        teleoperate=True,
    )

def stop_recording(robot, listener, display_data):
    robot.disconnect()

    if not is_headless() and listener is not None:
        listener.stop()

def sanity_check_dataset_name(repo_id, policy_cfg):
    _, dataset_name = repo_id.split("/")
    # Either repo_id doesn't start with "eval_" and there is no policy,
    # or repo_id starts with "eval_" and there is a policy.

    # Check if dataset_name starts with "eval_" but policy is missing
    if dataset_name.startswith("eval_") and policy_cfg is None:
        raise ValueError(
            f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
        )

    # Check if dataset_name does not start with "eval_" but policy is provided
    if not dataset_name.startswith("eval_") and policy_cfg is not None:
        raise ValueError(
            f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
        )

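
# Example (hypothetical repo ids): "lerobot_user/eval_so100_pick" requires a policy to be passed,
# while "lerobot_user/so100_pick" must be recorded without one; any other combination raises a ValueError.
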
def sanity_check_dataset_robot_compatibility(
    dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
) -> None:
    fields = [
        ("robot_type", dataset.meta.robot_type, robot.robot_type),
        ("fps", dataset.fps, fps),
        ("features", dataset.features, get_features_from_robot(robot, use_videos)),
    ]

    mismatches = []
    for field, dataset_value, present_value in fields:
        diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
        if diff:
            mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")

    if mismatches:
        raise ValueError(
            "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
        )
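

# End-to-end usage sketch (hypothetical names and durations; `robot`, `dataset` and `num_episodes`
# are assumed to be created elsewhere, e.g. by the recording script from its config):
#
#   listener, events = init_keyboard_listener()
#   warmup_record(robot, events, enable_teleoperation=True, warmup_time_s=5, display_data=False, fps=30)
#   for _ in range(num_episodes):
#       record_episode(robot, dataset, events, episode_time_s=30, display_data=False,
#                      policy=None, fps=30, single_task="pick up the cube")
#       if events["rerecord_episode"]:
#           events["rerecord_episode"] = False
#           ...  # discard the episode buffer and record it again
#           continue
#       dataset.save_episode()
#       if events["stop_recording"]:
#           break
#       reset_environment(robot, events, reset_time_s=10, fps=30)
#   stop_recording(robot, listener, display_data=False)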