Enable CI for robot devices with mocked versions (#398)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Remi 2024-10-03 17:05:23 +02:00 committed by GitHub
parent 72f402d44b
commit 26f97cfd17
18 changed files with 1053 additions and 237 deletions

View File

@ -11,6 +11,7 @@ on:
- ".github/**"
- "poetry.lock"
- "Makefile"
- ".cache/**"
push:
branches:
- main
@ -21,6 +22,7 @@ on:
- ".github/**"
- "poetry.lock"
- "Makefile"
- ".cache/**"
jobs:
pytest:
@ -60,7 +62,6 @@ jobs:
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
&& rm -rf tests/outputs outputs
pytest-minimal:
name: Pytest (minimal install)
runs-on: ubuntu-latest

View File

@ -28,6 +28,8 @@ Example:
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
print(lerobot.available_robots)
print(lerobot.available_cameras)
print(lerobot.available_motors)
```
When implementing a new dataset loadable with LeRobotDataset, follow these steps:
@ -198,6 +200,17 @@ available_robots = [
"aloha",
]
# lists all available cameras from `lerobot/common/robot_devices/cameras`
available_cameras = [
"opencv",
"intelrealsense",
]
# lists all available motors from `lerobot/common/robot_devices/motors`
available_motors = [
"dynamixel",
]
# keys and values refer to yaml files
available_policies_per_env = {
"aloha": ["act"],

View File

@ -68,7 +68,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
return stats_patterns
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
if max_num_samples is None:
max_num_samples = len(dataset)

View File

@ -5,6 +5,7 @@ This file contains utilities for recording frames from Intel Realsense cameras.
import argparse
import concurrent.futures
import logging
import math
import shutil
import threading
import time
@ -13,9 +14,7 @@ from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
import cv2
import numpy as np
import pyrealsense2 as rs
from PIL import Image
from lerobot.common.robot_devices.utils import (
@ -28,14 +27,23 @@ from lerobot.scripts.control_robot import busy_wait
SERIAL_NUMBER_INDEX = 1
def find_camera_indices(raise_when_empty=True) -> list[int]:
def find_camera_indices(raise_when_empty=True, mock=False) -> list[int]:
"""
Find the serial numbers of the Intel RealSense cameras
connected to the computer.
"""
if mock:
from tests.mock_pyrealsense2 import (
RSCameraInfo,
RSContext,
)
else:
from pyrealsense2 import camera_info as RSCameraInfo # noqa: N812
from pyrealsense2 import context as RSContext # noqa: N812
camera_ids = []
for device in rs.context().query_devices():
serial_number = int(device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX)))
for device in RSContext().query_devices():
serial_number = int(device.get_info(RSCameraInfo(SERIAL_NUMBER_INDEX)))
camera_ids.append(serial_number)
if raise_when_empty and len(camera_ids) == 0:
@ -64,18 +72,24 @@ def save_images_from_cameras(
width=None,
height=None,
record_time_s=2,
mock=False,
):
"""
Initializes all the cameras and saves images to the directory. Useful to visually identify the camera
associated with a given camera index.
"""
if camera_ids is None:
camera_ids = find_camera_indices()
camera_ids = find_camera_indices(mock=mock)
if mock:
from tests.mock_cv2 import COLOR_RGB2BGR, cvtColor
else:
from cv2 import COLOR_RGB2BGR, cvtColor
print("Connecting cameras")
cameras = []
for cam_idx in camera_ids:
camera = IntelRealSenseCamera(cam_idx, fps=fps, width=width, height=height)
camera = IntelRealSenseCamera(cam_idx, fps=fps, width=width, height=height, mock=mock)
camera.connect()
print(
f"IntelRealSenseCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
@ -103,7 +117,8 @@ def save_images_from_cameras(
image = camera.read() if fps is None else camera.async_read()
if image is None:
print("No Frame")
bgr_converted_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
bgr_converted_image = cvtColor(image, COLOR_RGB2BGR)
executor.submit(
save_image,
@ -149,6 +164,7 @@ class IntelRealSenseCameraConfig:
color_mode: str = "rgb"
use_depth: bool = False
force_hardware_reset: bool = True
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
@ -156,7 +172,9 @@ class IntelRealSenseCameraConfig:
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
if (self.fps or self.width or self.height) and not (self.fps and self.width and self.height):
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
f"but {self.fps=}, {self.width=}, {self.height=} were provided."
@ -228,6 +246,7 @@ class IntelRealSenseCamera:
self.color_mode = config.color_mode
self.use_depth = config.use_depth
self.force_hardware_reset = config.force_hardware_reset
self.mock = config.mock
self.camera = None
self.is_connected = False
@ -243,24 +262,37 @@ class IntelRealSenseCamera:
f"IntelRealSenseCamera({self.camera_index}) is already connected."
)
config = rs.config()
if self.mock:
from tests.mock_pyrealsense2 import (
RSConfig,
RSFormat,
RSPipeline,
RSStream,
)
else:
from pyrealsense2 import config as RSConfig # noqa: N812
from pyrealsense2 import format as RSFormat # noqa: N812
from pyrealsense2 import pipeline as RSPipeline # noqa: N812
from pyrealsense2 import stream as RSStream # noqa: N812
config = RSConfig()
config.enable_device(str(self.camera_index))
if self.fps and self.width and self.height:
# TODO(rcadene): can we set rgb8 directly?
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
config.enable_stream(RSStream.color, self.width, self.height, RSFormat.rgb8, self.fps)
else:
config.enable_stream(rs.stream.color)
config.enable_stream(RSStream.color)
if self.use_depth:
if self.fps and self.width and self.height:
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
config.enable_stream(RSStream.depth, self.width, self.height, RSFormat.z16, self.fps)
else:
config.enable_stream(rs.stream.depth)
config.enable_stream(RSStream.depth)
self.camera = rs.pipeline()
self.camera = RSPipeline()
try:
self.camera.start(config)
profile = self.camera.start(config)
is_camera_open = True
except RuntimeError:
is_camera_open = False
@ -279,6 +311,31 @@ class IntelRealSenseCamera:
raise OSError(f"Can't access IntelRealSenseCamera({self.camera_index}).")
color_stream = profile.get_stream(RSStream.color)
color_profile = color_stream.as_video_stream_profile()
actual_fps = color_profile.fps()
actual_width = color_profile.width()
actual_height = color_profile.height()
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
# Using `OSError` since it's a broad error that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for IntelRealSenseCamera({self.camera_index}). Actual value is {actual_fps}."
)
if self.width is not None and self.width != actual_width:
raise OSError(
f"Can't set {self.width=} for IntelRealSenseCamera({self.camera_index}). Actual value is {actual_width}."
)
if self.height is not None and self.height != actual_height:
raise OSError(
f"Can't set {self.height=} for IntelRealSenseCamera({self.camera_index}). Actual value is {actual_height}."
)
self.fps = round(actual_fps)
self.width = round(actual_width)
self.height = round(actual_height)
self.is_connected = True
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
@ -315,7 +372,12 @@ class IntelRealSenseCamera:
# IntelRealSense uses RGB format as default (red, green, blue).
if requested_color_mode == "bgr":
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
if self.mock:
from tests.mock_cv2 import COLOR_RGB2BGR, cvtColor
else:
from cv2 import COLOR_RGB2BGR, cvtColor
color_image = cvtColor(color_image, COLOR_RGB2BGR)
h, w, _ = color_image.shape
if h != self.height or w != self.width:
@ -347,7 +409,7 @@ class IntelRealSenseCamera:
return color_image
def read_loop(self):
while self.stop_event is None or not self.stop_event.is_set():
while not self.stop_event.is_set():
if self.use_depth:
self.color_image, self.depth_map = self.read()
else:
@ -368,6 +430,7 @@ class IntelRealSenseCamera:
num_tries = 0
while self.color_image is None:
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
num_tries += 1
time.sleep(1 / self.fps)
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):

View File

@ -13,7 +13,6 @@ from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
import cv2
import numpy as np
from PIL import Image
@ -24,10 +23,6 @@ from lerobot.common.robot_devices.utils import (
)
from lerobot.common.utils.utils import capture_timestamp_utc
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
# when other threads are used to save the images.
cv2.setNumThreads(1)
# The maximum opencv device index depends on your operating system. For instance,
# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
@ -36,7 +31,7 @@ cv2.setNumThreads(1)
MAX_OPENCV_INDEX = 60
def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX):
def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False):
if platform.system() == "Linux":
# Linux uses camera ports
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
@ -51,9 +46,14 @@ def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENC
)
possible_camera_ids = range(max_index_search_range)
if mock:
from tests.mock_cv2 import VideoCapture
else:
from cv2 import VideoCapture
camera_ids = []
for camera_idx in possible_camera_ids:
camera = cv2.VideoCapture(camera_idx)
camera = VideoCapture(camera_idx)
is_open = camera.isOpened()
camera.release()
@ -78,19 +78,25 @@ def save_image(img_array, camera_index, frame_index, images_dir):
def save_images_from_cameras(
images_dir: Path, camera_ids: list[int] | None = None, fps=None, width=None, height=None, record_time_s=2
images_dir: Path,
camera_ids: list[int] | None = None,
fps=None,
width=None,
height=None,
record_time_s=2,
mock=False,
):
"""
Initializes all the cameras and saves images to the directory. Useful to visually identify the camera
associated with a given camera index.
"""
if camera_ids is None:
camera_ids = find_camera_indices()
camera_ids = find_camera_indices(mock=mock)
print("Connecting cameras")
cameras = []
for cam_idx in camera_ids:
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height)
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height, mock=mock)
camera.connect()
print(
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
@ -156,6 +162,7 @@ class OpenCVCameraConfig:
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
@ -215,6 +222,7 @@ class OpenCVCamera:
self.width = config.width
self.height = config.height
self.color_mode = config.color_mode
self.mock = config.mock
self.camera = None
self.is_connected = False
@ -227,17 +235,33 @@ class OpenCVCamera:
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
if self.mock:
from tests.mock_cv2 import (
CAP_PROP_FPS,
CAP_PROP_FRAME_HEIGHT,
CAP_PROP_FRAME_WIDTH,
VideoCapture,
)
else:
from cv2 import (
CAP_PROP_FPS,
CAP_PROP_FRAME_HEIGHT,
CAP_PROP_FRAME_WIDTH,
VideoCapture,
setNumThreads,
)
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
# when other threads are used to save the images.
setNumThreads(1)
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
if platform.system() == "Linux":
# Linux uses ports for connecting to cameras
tmp_camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
else:
tmp_camera = cv2.VideoCapture(self.camera_index)
tmp_camera = VideoCapture(camera_idx)
is_camera_open = tmp_camera.isOpened()
# Release camera to make it accessible for `find_camera_indices`
tmp_camera.release()
del tmp_camera
# If the camera doesn't work, display the camera indices corresponding to
@ -251,28 +275,27 @@ class OpenCVCamera:
"To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`."
)
raise OSError(f"Can't access OpenCVCamera({self.camera_index}).")
raise OSError(f"Can't access OpenCVCamera({camera_idx}).")
# Secondly, create the camera that will be used downstream.
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
# needs to be re-created.
if platform.system() == "Linux":
self.camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
else:
self.camera = cv2.VideoCapture(self.camera_index)
self.camera = VideoCapture(camera_idx)
if self.fps is not None:
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
self.camera.set(CAP_PROP_FPS, self.fps)
if self.width is not None:
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
self.camera.set(CAP_PROP_FRAME_WIDTH, self.width)
if self.height is not None:
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
self.camera.set(CAP_PROP_FRAME_HEIGHT, self.height)
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
actual_fps = self.camera.get(CAP_PROP_FPS)
actual_width = self.camera.get(CAP_PROP_FRAME_WIDTH)
actual_height = self.camera.get(CAP_PROP_FRAME_HEIGHT)
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
# Using `OSError` since it's a broad error that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
)
@ -285,9 +308,9 @@ class OpenCVCamera:
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
)
self.fps = actual_fps
self.width = actual_width
self.height = actual_height
self.fps = round(actual_fps)
self.width = round(actual_width)
self.height = round(actual_height)
self.is_connected = True
@ -306,6 +329,7 @@ class OpenCVCamera:
start_time = time.perf_counter()
ret, color_image = self.camera.read()
if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
@ -320,7 +344,12 @@ class OpenCVCamera:
# However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks,
# so we convert the image color from BGR to RGB.
if requested_color_mode == "rgb":
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
if self.mock:
from tests.mock_cv2 import COLOR_BGR2RGB, cvtColor
else:
from cv2 import COLOR_BGR2RGB, cvtColor
color_image = cvtColor(color_image, COLOR_BGR2RGB)
h, w, _ = color_image.shape
if h != self.height or w != self.width:
@ -334,11 +363,16 @@ class OpenCVCamera:
# log the utc time at which the image was received
self.logs["timestamp_utc"] = capture_timestamp_utc()
self.color_image = color_image
return color_image
def read_loop(self):
while self.stop_event is None or not self.stop_event.is_set():
self.color_image = self.read()
while not self.stop_event.is_set():
try:
self.color_image = self.read()
except Exception as e:
print(f"Error reading in thread: {e}")
def async_read(self):
if not self.is_connected:
@ -353,15 +387,14 @@ class OpenCVCamera:
self.thread.start()
num_tries = 0
while self.color_image is None:
num_tries += 1
time.sleep(1 / self.fps)
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
raise Exception(
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
)
while True:
if self.color_image is not None:
return self.color_image
return self.color_image
time.sleep(1 / self.fps)
num_tries += 1
if num_tries > self.fps * 2:
raise TimeoutError("Timed out waiting for async_read() to start.")
def disconnect(self):
if not self.is_connected:
@ -369,16 +402,14 @@ class OpenCVCamera:
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
)
if self.thread is not None and self.thread.is_alive():
# wait for the thread to finish
if self.thread is not None:
self.stop_event.set()
self.thread.join()
self.thread.join() # wait for the thread to finish
self.thread = None
self.stop_event = None
self.camera.release()
self.camera = None
self.is_connected = False
def __del__(self):

View File

@ -8,17 +8,6 @@ from pathlib import Path
import numpy as np
import tqdm
from dynamixel_sdk import (
COMM_SUCCESS,
DXL_HIBYTE,
DXL_HIWORD,
DXL_LOBYTE,
DXL_LOWORD,
GroupSyncRead,
GroupSyncWrite,
PacketHandler,
PortHandler,
)
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
@ -166,7 +155,17 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
return steps
def convert_to_bytes(value, bytes):
def convert_to_bytes(value, bytes, mock=False):
if mock:
return value
from dynamixel_sdk import (
DXL_HIBYTE,
DXL_HIWORD,
DXL_LOBYTE,
DXL_LOWORD,
)
# Note: No need to convert back into unsigned int, since this byte preprocessing
# already handles it for us.
if bytes == 1:
@ -333,9 +332,11 @@ class DynamixelMotorsBus:
motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None,
extra_model_resolution: dict[str, int] | None = None,
mock=False,
):
self.port = port
self.motors = motors
self.mock = mock
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
if extra_model_control_table:
@ -359,6 +360,11 @@ class DynamixelMotorsBus:
f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
)
if self.mock:
from tests.mock_dynamixel_sdk import PacketHandler, PortHandler
else:
from dynamixel_sdk import PacketHandler, PortHandler
self.port_handler = PortHandler(self.port)
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
@ -392,10 +398,17 @@ class DynamixelMotorsBus:
self.configure_motors()
def reconnect(self):
if self.mock:
from tests.mock_dynamixel_sdk import PacketHandler, PortHandler
else:
from dynamixel_sdk import PacketHandler, PortHandler
self.port_handler = PortHandler(self.port)
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
self.is_connected = True
def are_motors_configured(self):
@ -781,6 +794,11 @@ class DynamixelMotorsBus:
return values
def _read_with_motor_ids(self, motor_models, motor_ids, data_name):
if self.mock:
from tests.mock_dynamixel_sdk import COMM_SUCCESS, GroupSyncRead
else:
from dynamixel_sdk import COMM_SUCCESS, GroupSyncRead
return_list = True
if not isinstance(motor_ids, list):
return_list = False
@ -817,6 +835,11 @@ class DynamixelMotorsBus:
start_time = time.perf_counter()
if self.mock:
from tests.mock_dynamixel_sdk import COMM_SUCCESS, GroupSyncRead
else:
from dynamixel_sdk import COMM_SUCCESS, GroupSyncRead
if motor_names is None:
motor_names = self.motor_names
@ -876,6 +899,11 @@ class DynamixelMotorsBus:
return values
def _write_with_motor_ids(self, motor_models, motor_ids, data_name, values):
if self.mock:
from tests.mock_dynamixel_sdk import COMM_SUCCESS, GroupSyncWrite
else:
from dynamixel_sdk import COMM_SUCCESS, GroupSyncWrite
if not isinstance(motor_ids, list):
motor_ids = [motor_ids]
if not isinstance(values, list):
@ -885,7 +913,7 @@ class DynamixelMotorsBus:
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
group = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes)
data = convert_to_bytes(value, bytes, self.mock)
group.addParam(idx, data)
comm = group.txPacket()
@ -903,6 +931,11 @@ class DynamixelMotorsBus:
start_time = time.perf_counter()
if self.mock:
from tests.mock_dynamixel_sdk import COMM_SUCCESS, GroupSyncWrite
else:
from dynamixel_sdk import COMM_SUCCESS, GroupSyncWrite
if motor_names is None:
motor_names = self.motor_names
@ -937,7 +970,7 @@ class DynamixelMotorsBus:
)
for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes)
data = convert_to_bytes(value, bytes, self.mock)
if init_group:
self.group_writers[group_key].addParam(idx, data)
else:

View File

@ -242,7 +242,8 @@ def is_headless():
########################################################################################
def calibrate(robot: Robot, arms: list[str] | None):
def get_available_arms(robot):
# TODO(rcadene): moves this function in manipulator class?
available_arms = []
for name in robot.follower_arms:
arm_id = get_arm_id(name, "follower")
@ -250,9 +251,12 @@ def calibrate(robot: Robot, arms: list[str] | None):
for name in robot.leader_arms:
arm_id = get_arm_id(name, "leader")
available_arms.append(arm_id)
return available_arms
def calibrate(robot: Robot, arms: list[str] | None):
available_arms = get_available_arms(robot)
unknown_arms = [arm_id for arm_id in arms if arm_id not in available_arms]
available_arms_str = " ".join(available_arms)
unknown_arms_str = " ".join(unknown_arms)
@ -323,6 +327,7 @@ def record(
tags=None,
num_image_writers_per_camera=4,
force_override=False,
display_cameras=True,
):
# TODO(rcadene): Add option to record logs
# TODO(rcadene): Clean this function via decomposition in higher level functions
@ -333,9 +338,6 @@ def record(
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
)
if not video:
raise NotImplementedError()
if not robot.is_connected:
robot.connect()
@ -359,7 +361,7 @@ def record(
episode_index = 0
if is_headless():
logging.info(
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
@ -427,7 +429,7 @@ def record(
else:
observation = robot.capture_observation()
if not is_headless():
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@ -445,6 +447,7 @@ def record(
# Using `with` to exit smoothly if an exception is raised.
futures = []
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
num_image_writers = max(num_image_writers, 1)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
# Start recording all episodes
while episode_index < num_episodes:
@ -472,7 +475,7 @@ def record(
)
]
if not is_headless():
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@ -545,15 +548,23 @@ def record(
num_frames = frame_index
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
video_path.unlink()
# Store the reference to the video frame, even though the videos are not yet encoded
ep_dict[key] = []
for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
if video:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
video_path.unlink()
# Store the reference to the video frame, even though the videos are not yet encoded
ep_dict[key] = []
for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
else:
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
ep_dict[key] = []
for i in range(num_frames):
img_path = imgs_dir / f"frame_{i:06d}.png"
ep_dict[key].append({"path": str(img_path)})
for key in not_image_keys:
ep_dict[key] = torch.stack(ep_dict[key])
@ -612,26 +623,27 @@ def record(
break
robot.disconnect()
if not is_headless():
if display_cameras and not is_headless():
cv2.destroyAllWindows()
num_episodes = episode_index
logging.info("Encoding videos")
say("Encoding videos")
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
if video:
logging.info("Encoding videos")
say("Encoding videos")
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
logging.info("Concatenating episodes")
ep_dicts = []

poetry.lock (generated, 7 changes)
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "absl-py"
@ -2406,7 +2406,6 @@ description = "Nvidia JIT LTO Library"
optional = false
python-versions = ">=3"
files = [
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_aarch64.whl", hash = "sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27"},
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"},
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"},
]
@ -4578,7 +4577,7 @@ dora = ["gym-dora"]
dynamixel = ["dynamixel-sdk", "pynput"]
intelrealsense = ["pyrealsense2"]
pusht = ["gym-pusht"]
test = ["pytest", "pytest-cov"]
test = ["pyserial", "pytest", "pytest-cov"]
umi = ["imagecodecs"]
video-benchmark = ["pandas", "scikit-image"]
xarm = ["gym-xarm"]
@ -4586,4 +4585,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "c9c3beac71f760738baf2fd169378eefdaef7d3a9cd068270bc5190fbefdb42a"
content-hash = "5e4f6b9727d67a37d1c6d94af7661bb688a0866afd30878c5e523b8e768deac6"

View File

@ -67,6 +67,7 @@ pynput = {version = ">=1.7.7", optional = true}
# TODO(rcadene, salibert): 71.0.1 has a bug
setuptools = {version = "!=71.0.1", optional = true}
pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'", optional = true}
pyserial = {version = ">=3.5", optional = true}
[tool.poetry.extras]
@ -75,7 +76,7 @@ pusht = ["gym-pusht"]
xarm = ["gym-xarm"]
aloha = ["gym-aloha"]
dev = ["pre-commit", "debugpy"]
test = ["pytest", "pytest-cov"]
test = ["pytest", "pytest-cov", "pyserial"]
umi = ["imagecodecs"]
video_benchmark = ["scikit-image", "pandas"]
dynamixel = ["dynamixel-sdk", "pynput"]

View File

@ -13,13 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
import pytest
from serial import SerialException
from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE
from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
def pytest_collection_finish():
@ -28,6 +30,11 @@ def pytest_collection_finish():
@pytest.fixture
def is_robot_available(robot_type):
if robot_type not in available_robots:
raise ValueError(
f"The robot type '{robot_type}' is not valid. Expected one of these '{available_robots}"
)
try:
from lerobot.common.robot_devices.robots.factory import make_robot
@ -37,7 +44,73 @@ def is_robot_available(robot_type):
robot.connect()
del robot
return True
except Exception:
traceback.print_exc()
except Exception as e:
print(f"\nA {robot_type} robot is not available.")
if isinstance(e, ModuleNotFoundError):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, SerialException):
print("\nNo physical motors bus detected.")
traceback.print_exc()
return False
@pytest.fixture
def is_camera_available(camera_type):
if camera_type not in available_cameras:
raise ValueError(
f"The camera type '{camera_type}' is not valid. Expected one of these '{available_cameras}"
)
try:
camera = make_camera(camera_type)
camera.connect()
del camera
return True
except Exception as e:
print(f"\nA {camera_type} camera is not available.")
if isinstance(e, ModuleNotFoundError):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, ValueError) and "camera_index" in e.args[0]:
print("\nNo physical camera detected.")
traceback.print_exc()
return False
@pytest.fixture
def is_motor_available(motor_type):
if motor_type not in available_motors:
raise ValueError(
f"The motor type '{motor_type}' is not valid. Expected one of these '{available_motors}"
)
try:
motors_bus = make_motors_bus(motor_type)
motors_bus.connect()
del motors_bus
return True
except Exception as e:
print(f"\nA {motor_type} motor is not available.")
if isinstance(e, ModuleNotFoundError):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, SerialException):
print("\nNo physical motors bus detected.")
traceback.print_exc()
return False
@pytest.fixture
def patch_builtins_input(monkeypatch):
def print_text(text=None):
if text is not None:
print(text)
monkeypatch.setattr("builtins.input", print_text)

tests/mock_cv2.py (new file, 66 lines)
View File

@ -0,0 +1,66 @@
from functools import cache
import numpy as np
CAP_PROP_FPS = 5
CAP_PROP_FRAME_WIDTH = 3
CAP_PROP_FRAME_HEIGHT = 4
COLOR_RGB2BGR = 4
COLOR_BGR2RGB = 4
@cache
def _generate_image(width: int, height: int):
return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
def cvtColor(color_image, color_convertion): # noqa: N802
if color_convertion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
return color_image[:, :, [2, 1, 0]]
else:
raise NotImplementedError(color_convertion)
class VideoCapture:
def __init__(self, *args, **kwargs):
self._mock_dict = {
CAP_PROP_FPS: 30,
CAP_PROP_FRAME_WIDTH: 640,
CAP_PROP_FRAME_HEIGHT: 480,
}
self._is_opened = True
def isOpened(self): # noqa: N802
return self._is_opened
def set(self, propId: int, value: float) -> bool: # noqa: N803
if not self._is_opened:
raise RuntimeError("Camera is not opened")
self._mock_dict[propId] = value
return True
def get(self, propId: int) -> float: # noqa: N803
if not self._is_opened:
raise RuntimeError("Camera is not opened")
value = self._mock_dict[propId]
if value == 0:
if propId == CAP_PROP_FRAME_HEIGHT:
value = 480
elif propId == CAP_PROP_FRAME_WIDTH:
value = 640
return value
def read(self):
if not self._is_opened:
raise RuntimeError("Camera is not opened")
h = self.get(CAP_PROP_FRAME_HEIGHT)
w = self.get(CAP_PROP_FRAME_WIDTH)
ret = True
return ret, _generate_image(width=w, height=h)
def release(self):
self._is_opened = False
def __del__(self):
if self._is_opened:
self.release()

View File

@ -0,0 +1,87 @@
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
Warning: These mocked versions are minimalist. They do not exactly mimic every behavior
of the original classes and functions (e.g. return types might be None instead of boolean).
"""
# from dynamixel_sdk import COMM_SUCCESS
DEFAULT_BAUDRATE = 9_600
COMM_SUCCESS = 0 # tx or rx packet communication success
def convert_to_bytes(value, bytes):
# TODO(rcadene): remove need to mock `convert_to_bytes` by implementing the inverse transform
# `convert_bytes_to_value`
del bytes # unused
return value
class PortHandler:
def __init__(self, port):
self.port = port
# factory default baudrate
self.baudrate = DEFAULT_BAUDRATE
def openPort(self): # noqa: N802
return True
def closePort(self): # noqa: N802
pass
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
del timeout_ms # unused
def getBaudRate(self): # noqa: N802
return self.baudrate
def setBaudRate(self, baudrate): # noqa: N802
self.baudrate = baudrate
class PacketHandler:
def __init__(self, protocol_version):
del protocol_version # unused
# Use packet_handler.data to communicate across Read and Write
self.data = {}
class GroupSyncRead:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
def addParam(self, motor_index): # noqa: N802
if motor_index not in self.packet_handler.data:
# Initialize motor default values
self.packet_handler.data[motor_index] = {
# Key (int) are from X_SERIES_CONTROL_TABLE
7: motor_index, # ID
8: DEFAULT_BAUDRATE, # Baud_rate
10: 0, # Drive_Mode
64: 0, # Torque_Enable
# Set 2560 since calibration values for the Aloha gripper are between start_pos=2499 and end_pos=3144
# For other joints, 2560 will be autocorrected to be in calibration range
132: 2560, # Present_Position
}
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS
def getData(self, index, address, bytes): # noqa: N802
return self.packet_handler.data[index][address]
class GroupSyncWrite:
def __init__(self, port_handler, packet_handler, address, bytes):
self.packet_handler = packet_handler
self.address = address
def addParam(self, index, data): # noqa: N802
self.changeParam(index, data)
def txPacket(self): # noqa: N802
return COMM_SUCCESS
def changeParam(self, index, data): # noqa: N802
self.packet_handler.data[index][self.address] = data
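The mocked SDK above is what `DynamixelMotorsBus` resolves when it is constructed with `mock=True`. A rough usage sketch, mirroring `tests/test_motors.py` and assuming the repository's test helpers are importable:

```python
# Minimal sketch: build a mocked Dynamixel bus through the test helper and connect without
# a serial device; PortHandler/PacketHandler come from tests/mock_dynamixel_sdk.py.
from tests.utils import make_motors_bus

motors_bus = make_motors_bus("dynamixel", mock=True)
motors_bus.connect()
print(motors_bus.motor_names)
```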

tests/mock_pyrealsense2.py (new file, 134 lines)
View File

@ -0,0 +1,134 @@
import enum
import numpy as np
class RSStream(enum.Enum):
color = 0
depth = 1
class RSFormat(enum.Enum):
rgb8 = 0
z16 = 1
class RSConfig:
def enable_device(self, device_id: str):
self.device_enabled = device_id
def enable_stream(
self, stream_type: RSStream, width=None, height=None, color_format: RSFormat = None, fps=None
):
self.stream_type = stream_type
# Overwrite default values when possible
self.width = 848 if width is None else width
self.height = 480 if height is None else height
self.color_format = RSFormat.rgb8 if color_format is None else color_format
self.fps = 30 if fps is None else fps
class RSColorProfile:
def __init__(self, config: RSConfig):
self.config = config
def fps(self):
return self.config.fps
def width(self):
return self.config.width
def height(self):
return self.config.height
class RSColorStream:
def __init__(self, config: RSConfig):
self.config = config
def as_video_stream_profile(self):
return RSColorProfile(self.config)
class RSProfile:
def __init__(self, config: RSConfig):
self.config = config
def get_stream(self, color_format: RSFormat):
del color_format # unused
return RSColorStream(self.config)
class RSPipeline:
def __init__(self):
self.started = False
self.config = None
def start(self, config: RSConfig):
self.started = True
self.config = config
return RSProfile(self.config)
def stop(self):
if not self.started:
raise RuntimeError("You need to start the camera before stop.")
self.started = False
self.config = None
def wait_for_frames(self, timeout_ms=50000):
del timeout_ms # unused
return RSFrames(self.config)
class RSFrames:
def __init__(self, config: RSConfig):
self.config = config
def get_color_frame(self):
return RSColorFrame(self.config)
def get_depth_frame(self):
return RSDepthFrame(self.config)
class RSColorFrame:
def __init__(self, config: RSConfig):
self.config = config
def get_data(self):
data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8)
# Create a difference between rgb and bgr
data[:, :, 0] = 2
return data
class RSDepthFrame:
def __init__(self, config: RSConfig):
self.config = config
def get_data(self):
return np.ones((self.config.height, self.config.width), dtype=np.uint16)
class RSDevice:
def __init__(self):
pass
def get_info(self, camera_info) -> str:
del camera_info # unused
# return fake serial number
return "123456789"
class RSContext:
def __init__(self):
pass
def query_devices(self):
return [RSDevice()]
class RSCameraInfo:
def __init__(self, serial_number):
del serial_number
pass

View File

@ -1,21 +1,32 @@
"""
Tests meant to be used locally and launched manually.
Tests for physical cameras and their mocked versions.
If the physical camera is not connected to the computer, or not working,
the test will be skipped.
Example usage:
Example of running a specific test:
```bash
pytest -sx tests/test_cameras.py::test_camera
```
Example of running test on a real camera connected to the computer:
```bash
pytest -sx 'tests/test_cameras.py::test_camera[opencv-False]'
pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-False]'
```
Example of running test on a mocked version of the camera:
```bash
pytest -sx 'tests/test_cameras.py::test_camera[opencv-True]'
pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]'
```
"""
import numpy as np
import pytest
from lerobot import available_robots
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera, save_images_from_cameras
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from tests.utils import require_robot
from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera
CAMERA_INDEX = 2
# Maximum absolute difference between two consecutive images recorded by a camera.
# This value varies from camera to camera.
MAX_PIXEL_DIFFERENCE = 25
@ -25,9 +36,9 @@ def compute_max_pixel_difference(first_image, second_image):
return np.abs(first_image.astype(float) - second_image.astype(float)).max()
@pytest.mark.parametrize("robot_type", available_robots)
@require_robot
def test_camera(request, robot_type):
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
@require_camera
def test_camera(request, camera_type, mock):
"""Test assumes that `camera.read()` returns the same image when called multiple times in a row.
So the environment should not change (you shouldn't be in front of the camera) and the camera should not be moving.
@ -36,10 +47,12 @@ def test_camera(request, robot_type):
"""
# TODO(rcadene): measure fps in nightly?
# TODO(rcadene): test logs
# TODO(rcadene): add compatibility with other camera APIs
if camera_type == "opencv" and not mock:
pytest.skip("TODO(rcadene): fix test for opencv physical camera")
# Test instantiating
camera = OpenCVCamera(CAMERA_INDEX)
camera = make_camera(camera_type, mock=mock)
# Test reading, async reading, disconnecting before connecting raises an error
with pytest.raises(RobotDeviceNotConnectedError):
@ -53,7 +66,7 @@ def test_camera(request, robot_type):
del camera
# Test connecting
camera = OpenCVCamera(CAMERA_INDEX)
camera = make_camera(camera_type, mock=mock)
camera.connect()
assert camera.is_connected
assert camera.fps is not None
@ -78,11 +91,14 @@ def test_camera(request, robot_type):
camera.read()
color_image = camera.read()
async_color_image = camera.async_read()
print(
error_msg = (
"max_pixel_difference between read() and async_read()",
compute_max_pixel_difference(color_image, async_color_image),
)
assert np.allclose(color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE)
# TODO(rcadene): properly set `rtol`
np.testing.assert_allclose(
color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
)
# Test disconnecting
camera.disconnect()
@ -90,29 +106,25 @@ def test_camera(request, robot_type):
assert camera.thread is None
# Test disconnecting with `__del__`
camera = OpenCVCamera(CAMERA_INDEX)
camera = make_camera(camera_type, mock=mock)
camera.connect()
del camera
# Test acquiring a bgr image
camera = OpenCVCamera(CAMERA_INDEX, color_mode="bgr")
camera = make_camera(camera_type, color_mode="bgr", mock=mock)
camera.connect()
assert camera.color_mode == "bgr"
bgr_color_image = camera.read()
assert np.allclose(color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE)
np.testing.assert_allclose(
color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
)
del camera
# TODO(rcadene): Add a test for a camera that doesn't support fps=60 and raises an OSError
# TODO(rcadene): Add a test for a camera that supports fps=60
# Test fps=10 raises an OSError
camera = OpenCVCamera(CAMERA_INDEX, fps=10)
with pytest.raises(OSError):
camera.connect()
del camera
# Test width and height can be set
camera = OpenCVCamera(CAMERA_INDEX, fps=30, width=1280, height=720)
camera = make_camera(camera_type, fps=30, width=1280, height=720, mock=mock)
camera.connect()
assert camera.fps == 30
assert camera.width == 1280
@ -125,13 +137,19 @@ def test_camera(request, robot_type):
del camera
# Test not supported width and height raise an error
camera = OpenCVCamera(CAMERA_INDEX, fps=30, width=0, height=0)
camera = make_camera(camera_type, fps=30, width=0, height=0, mock=mock)
with pytest.raises(OSError):
camera.connect()
del camera
@pytest.mark.parametrize("robot_type", available_robots)
@require_robot
def test_save_images_from_cameras(tmpdir, request, robot_type):
save_images_from_cameras(tmpdir, record_time_s=1)
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
@require_camera
def test_save_images_from_cameras(tmpdir, request, camera_type, mock):
# TODO(rcadene): refactor
if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
save_images_from_cameras(tmpdir, record_time_s=1, mock=mock)

View File

@ -1,55 +1,146 @@
"""
Tests for physical robots and their mocked versions.
If the physical robots are not connected to the computer, or not working,
the test will be skipped.
Example of running a specific test:
```bash
pytest -sx tests/test_control_robot.py::test_teleoperate
```
Example of running test on real robots connected to the computer:
```bash
pytest -sx 'tests/test_control_robot.py::test_teleoperate[koch-False]'
pytest -sx 'tests/test_control_robot.py::test_teleoperate[koch_bimanual-False]'
pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-False]'
```
Example of running test on a mocked version of robots:
```bash
pytest -sx 'tests/test_control_robot.py::test_teleoperate[koch-True]'
pytest -sx 'tests/test_control_robot.py::test_teleoperate[koch_bimanual-True]'
pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
```
"""
from pathlib import Path
import pytest
from lerobot import available_robots
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from lerobot.scripts.control_robot import calibrate, get_available_arms, record, replay, teleoperate
from tests.test_robots import make_robot
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_robot
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, require_robot
@pytest.mark.parametrize("robot_type", available_robots)
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_teleoperate(request, robot_type):
robot = make_robot(robot_type)
def test_teleoperate(tmpdir, request, robot_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
tmpdir = Path(tmpdir)
calibration_dir = tmpdir / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False
overrides = None
robot = make_robot(robot_type, overrides=overrides, mock=mock)
teleoperate(robot, teleop_time_s=1)
teleoperate(robot, fps=30, teleop_time_s=1)
teleoperate(robot, fps=60, teleop_time_s=1)
del robot
@pytest.mark.parametrize("robot_type", available_robots)
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_calibrate(request, robot_type):
robot = make_robot(robot_type)
calibrate(robot)
def test_calibrate(tmpdir, request, robot_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
tmpdir = Path(tmpdir)
calibration_dir = tmpdir / robot_type
overrides_calibration_dir = [f"calibration_dir={calibration_dir}"]
robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
calibrate(robot, arms=get_available_arms(robot))
del robot
@pytest.mark.parametrize("robot_type", available_robots)
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_record_without_cameras(tmpdir, request, robot_type):
root = Path(tmpdir)
def test_record_without_cameras(tmpdir, request, robot_type, mock):
# Avoid using cameras
overrides = ["~cameras"]
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = Path(tmpdir) / robot_type
overrides.append(f"calibration_dir={calibration_dir}")
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
robot = make_robot(robot_type, overrides=["~cameras"])
record(robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2)
robot = make_robot(robot_type, overrides=overrides, mock=mock)
record(
robot,
fps=30,
root=root,
repo_id=repo_id,
warmup_time_s=1,
episode_time_s=1,
num_episodes=2,
run_compute_stats=False,
push_to_hub=False,
video=False,
)
@pytest.mark.parametrize("robot_type", available_robots)
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_record_and_replay_and_policy(tmpdir, request, robot_type):
def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = Path(tmpdir) / robot_type
overrides = [f"calibration_dir={calibration_dir}"]
else:
# Use the default .cache/calibration folder when mock=False
overrides = None
if robot_type == "aloha":
pytest.skip("TODO(rcadene): enable test once aloha_real and act_aloha_real are merged")
env_name = "koch_real"
policy_name = "act_koch_real"
root = Path(tmpdir)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
robot = make_robot(robot_type)
robot = make_robot(robot_type, overrides=overrides, mock=mock)
dataset = record(
robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2
robot,
fps=30,
root=root,
repo_id=repo_id,
warmup_time_s=1,
episode_time_s=1,
num_episodes=2,
push_to_hub=False,
# TODO(rcadene, aliberts): test video=True
video=False,
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
display_cameras=False,
)
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
@ -65,6 +156,17 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type):
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
record(robot, policy, cfg, run_time_s=1)
record(
robot,
policy,
cfg,
warmup_time_s=1,
episode_time_s=1,
num_episodes=2,
run_compute_stats=False,
push_to_hub=False,
video=False,
display_cameras=False,
)
del robot

View File

@ -1,11 +1,23 @@
"""
Tests meant to be used locally and launched manually.
Tests for physical motors and their mocked versions.
If the physical motors are not connected to the computer, or not working,
the test will be skipped.
Example usage:
Example of running a specific test:
```bash
pytest -sx tests/test_motors.py::test_find_port
pytest -sx tests/test_motors.py::test_motors_bus
```
Example of running test on real dynamixel motors connected to the computer:
```bash
pytest -sx 'tests/test_motors.py::test_motors_bus[dynamixel-False]'
```
Example of running test on a mocked version of dynamixel motors:
```bash
pytest -sx 'tests/test_motors.py::test_motors_bus[dynamixel-True]'
```
"""
# TODO(rcadene): measure fps in nightly?
@ -18,38 +30,31 @@ import time
import numpy as np
import pytest
from lerobot import available_robots
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.motors.dynamixel import find_port
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import ROBOT_CONFIG_PATH_TEMPLATE, require_robot
from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor
def make_motors_bus(robot_type: str) -> MotorsBus:
# Instantiate a robot and return one of its leader arms
config_path = ROBOT_CONFIG_PATH_TEMPLATE.format(robot=robot_type)
robot_cfg = init_hydra_config(config_path)
robot = make_robot(robot_cfg)
first_bus_name = list(robot.leader_arms.keys())[0]
motors_bus = robot.leader_arms[first_bus_name]
return motors_bus
@pytest.mark.parametrize("motor_type, mock", TEST_MOTOR_TYPES)
@require_motor
def test_find_port(request, motor_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
with pytest.raises(OSError):
find_port()
else:
find_port()
@pytest.mark.parametrize("robot_type", available_robots)
@require_robot
def test_find_port(request, robot_type):
from lerobot.common.robot_devices.motors.dynamixel import find_port
@pytest.mark.parametrize("motor_type, mock", TEST_MOTOR_TYPES)
@require_motor
def test_configure_motors_all_ids_1(request, motor_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
find_port()
@pytest.mark.parametrize("robot_type", available_robots)
@require_robot
def test_configure_motors_all_ids_1(request, robot_type):
input("Are you sure you want to re-configure the motors? Press enter to continue...")
# This test expects the configuration to already be correct.
motors_bus = make_motors_bus(robot_type)
motors_bus = make_motors_bus(motor_type, mock=mock)
motors_bus.connect()
motors_bus.write("Baud_Rate", [0] * len(motors_bus.motors))
motors_bus.set_bus_baudrate(9_600)
@ -57,16 +62,19 @@ def test_configure_motors_all_ids_1(request, robot_type):
del motors_bus
# Test configure
motors_bus = make_motors_bus(robot_type)
motors_bus = make_motors_bus(motor_type, mock=mock)
motors_bus.connect()
assert motors_bus.are_motors_configured()
del motors_bus
@pytest.mark.parametrize("robot_type", available_robots)
@require_robot
def test_motors_bus(request, robot_type):
motors_bus = make_motors_bus(robot_type)
@pytest.mark.parametrize("motor_type, mock", TEST_MOTOR_TYPES)
@require_motor
def test_motors_bus(request, motor_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
motors_bus = make_motors_bus(motor_type, mock=mock)
# Test that reading and writing before connecting raises an error
with pytest.raises(RobotDeviceNotConnectedError):
@ -80,7 +88,7 @@ def test_motors_bus(request, robot_type):
del motors_bus
# Test connecting
motors_bus = make_motors_bus(robot_type)
motors_bus = make_motors_bus(motor_type, mock=mock)
motors_bus.connect()
# Test connecting twice raises an error

View File

@ -1,10 +1,26 @@
"""
Tests meant to be used locally and launched manually.
Tests for physical robots and their mocked versions.
If the physical robots are not connected to the computer, or not working,
the test will be skipped.
Example usage:
Example of running a specific test:
```bash
pytest -sx tests/test_robots.py::test_robot
```
Example of running test on real robots connected to the computer:
```bash
pytest -sx 'tests/test_robots.py::test_robot[koch-False]'
pytest -sx 'tests/test_robots.py::test_robot[koch_bimanual-False]'
pytest -sx 'tests/test_robots.py::test_robot[aloha-False]'
```
Example of running test on a mocked version of robots:
```bash
pytest -sx 'tests/test_robots.py::test_robot[koch-True]'
pytest -sx 'tests/test_robots.py::test_robot[koch_bimanual-True]'
pytest -sx 'tests/test_robots.py::test_robot[aloha-True]'
```
"""
from pathlib import Path
@ -12,41 +28,42 @@ from pathlib import Path
import pytest
import torch
from lerobot import available_robots
from lerobot.common.robot_devices.robots.factory import make_robot as make_robot_from_cfg
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import ROBOT_CONFIG_PATH_TEMPLATE, require_robot
from tests.utils import TEST_ROBOT_TYPES, make_robot, require_robot
def make_robot(robot_type: str, overrides: list[str] | None = None) -> Robot:
config_path = ROBOT_CONFIG_PATH_TEMPLATE.format(robot=robot_type)
robot_cfg = init_hydra_config(config_path, overrides)
robot = make_robot_from_cfg(robot_cfg)
return robot
@pytest.mark.parametrize("robot_type", available_robots)
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_robot(tmpdir, request, robot_type):
def test_robot(tmpdir, request, robot_type, mock):
# TODO(rcadene): measure fps in nightly?
# TODO(rcadene): test logs
# TODO(rcadene): add compatibility with other robots
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
# Save calibration preset
tmpdir = Path(tmpdir)
calibration_dir = tmpdir / robot_type
robot_kwargs = {"robot_type": robot_type}
if robot_type == "aloha" and mock:
# To simplify the unit test, we do not rerun manual calibration for Aloha when mock=True.
# Instead, we use the files from '.cache/calibration/aloha_default'
overrides_calibration_dir = None
else:
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
tmpdir = Path(tmpdir)
calibration_dir = tmpdir / robot_type
overrides_calibration_dir = [f"calibration_dir={calibration_dir}"]
robot_kwargs["calibration_dir"] = calibration_dir
# Test connecting without devices raises an error
robot = ManipulatorRobot()
robot = ManipulatorRobot(**robot_kwargs)
with pytest.raises(ValueError):
robot.connect()
del robot
# Test using robot before connecting raises an error
robot = ManipulatorRobot()
robot = ManipulatorRobot(**robot_kwargs)
with pytest.raises(RobotDeviceNotConnectedError):
robot.teleop_step()
with pytest.raises(RobotDeviceNotConnectedError):
@ -61,21 +78,23 @@ def test_robot(tmpdir, request, robot_type):
# Test deleting the object without connecting first
del robot
# Test connecting
robot = make_robot(robot_type, overrides=[f"calibration_dir={calibration_dir}"])
robot.connect() # run the manual calibration procedure
# Test connecting (triggers manual calibration)
robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
robot.connect()
assert robot.is_connected
# Test connecting twice raises an error
with pytest.raises(RobotDeviceAlreadyConnectedError):
robot.connect()
# Test disconnecting with `__del__`
del robot
# TODO(rcadene, aliberts): Test disconnecting with `__del__` instead of `disconnect`
# del robot
robot.disconnect()
# Test teleop can run
robot = make_robot(robot_type, overrides=[f"calibration_dir={calibration_dir}"])
robot.calibration_dir = calibration_dir
robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
if overrides_calibration_dir is not None:
robot.calibration_dir = calibration_dir
robot.connect()
robot.teleop_step()
@ -121,4 +140,3 @@ def test_robot(tmpdir, request, robot_type):
assert not robot.leader_arms[name].is_connected
for name in robot.cameras:
assert not robot.cameras[name].is_connected
del robot

View File

@ -13,13 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
from copy import copy
from functools import wraps
import pytest
import torch
from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.robots.factory import make_robot as make_robot_from_cfg
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.import_utils import is_package_available
from lerobot.common.utils.utils import init_hydra_config
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@ -28,6 +36,32 @@ DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
ROBOT_CONFIG_PATH_TEMPLATE = "lerobot/configs/robot/{robot}.yaml"
TEST_ROBOT_TYPES = []
for robot_type in available_robots:
TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
TEST_CAMERA_TYPES = []
for camera_type in available_cameras:
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
TEST_MOTOR_TYPES = []
for motor_type in available_motors:
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
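# For instance, with available_robots = ["koch", "koch_bimanual", "aloha"], the list above
# expands to [("koch", True), ("koch", False), ("koch_bimanual", True), ...], so each device
# type is exercised both against mocks (mock=True) and against real hardware (mock=False).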
# Camera indices used for connecting physical cameras
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_CAMERA_INDEX", 128422271614))
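# For example, `LEROBOT_TEST_OPENCV_CAMERA_INDEX=1 pytest -sx tests/` selects a different
# physical OpenCV camera when running the tests against real hardware (sketch; the index
# value is illustrative).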
DYNAMIXEL_PORT = "/dev/tty.usbmodem575E0032081"
DYNAMIXEL_MOTORS = {
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
}
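# Each entry above maps a joint name to [motor index on the bus, motor model].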
def require_x86_64_kernel(func):
"""
@ -173,13 +207,136 @@ def require_robot(func):
# Access the pytest request context to get the is_robot_available fixture
request = kwargs.get("request")
robot_type = kwargs.get("robot_type")
mock = kwargs.get("mock")
if robot_type is None:
raise ValueError("The 'robot_type' must be an argument of the test function.")
if request is None:
raise ValueError("The 'request' fixture must be passed to the test function as a parameter.")
raise ValueError("The 'request' fixture must be an argument of the test function.")
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
# The function `is_robot_available` is defined in `tests/conftest.py`
if not request.getfixturevalue("is_robot_available"):
# Run test with a real robot. Skip test if robot connection fails.
if not mock and not request.getfixturevalue("is_robot_available"):
pytest.skip(f"A {robot_type} robot is not available.")
return func(*args, **kwargs)
return wrapper
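# Usage sketch (mirrors tests/test_robots.py): pytest passes `request`, `robot_type` and
# `mock` as keyword arguments, which the decorator inspects before running the test.
#
#   @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
#   @require_robot
#   def test_robot(tmpdir, request, robot_type, mock):
#       robot = make_robot(robot_type, mock=mock)
#       ...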
def require_camera(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Access the pytest request context to get the is_camera_available fixture
request = kwargs.get("request")
camera_type = kwargs.get("camera_type")
mock = kwargs.get("mock")
if request is None:
raise ValueError("The 'request' fixture must be an argument of the test function.")
if camera_type is None:
raise ValueError("The 'camera_type' must be an argument of the test function.")
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
if not mock and not request.getfixturevalue("is_camera_available"):
pytest.skip(f"A {camera_type} camera is not available.")
return func(*args, **kwargs)
return wrapper
def require_motor(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Access the pytest request context to get the is_motor_available fixture
request = kwargs.get("request")
motor_type = kwargs.get("motor_type")
mock = kwargs.get("mock")
if request is None:
raise ValueError("The 'request' fixture must be an argument of the test function.")
if motor_type is None:
raise ValueError("The 'motor_type' must be an argument of the test function.")
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
if not mock and not request.getfixturevalue("is_motor_available"):
pytest.skip(f"A {motor_type} motor is not available.")
return func(*args, **kwargs)
return wrapper
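# require_camera and require_motor follow the same pattern, parametrized with
# TEST_CAMERA_TYPES and TEST_MOTOR_TYPES respectively (sketch):
#
#   @pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
#   @require_camera
#   def test_camera(request, camera_type, mock):
#       ...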
def make_robot(robot_type: str, overrides: list[str] | None = None, mock=False) -> Robot:
if mock:
overrides = [] if overrides is None else copy(overrides)
# Explicitly add the mock argument to the arms and cameras and set it to true
# TODO(rcadene, aliberts): redesign when we drop hydra
if robot_type == "koch":
overrides.append("+leader_arms.main.mock=true")
overrides.append("+follower_arms.main.mock=true")
if "~cameras" not in overrides:
overrides.append("+cameras.laptop.mock=true")
overrides.append("+cameras.phone.mock=true")
elif robot_type == "koch_bimanual":
overrides.append("+leader_arms.left.mock=true")
overrides.append("+leader_arms.right.mock=true")
overrides.append("+follower_arms.left.mock=true")
overrides.append("+follower_arms.right.mock=true")
if "~cameras" not in overrides:
overrides.append("+cameras.laptop.mock=true")
overrides.append("+cameras.phone.mock=true")
elif robot_type == "aloha":
overrides.append("+leader_arms.left.mock=true")
overrides.append("+leader_arms.right.mock=true")
overrides.append("+follower_arms.left.mock=true")
overrides.append("+follower_arms.right.mock=true")
if "~cameras" not in overrides:
overrides.append("+cameras.cam_high.mock=true")
overrides.append("+cameras.cam_low.mock=true")
overrides.append("+cameras.cam_left_wrist.mock=true")
overrides.append("+cameras.cam_right_wrist.mock=true")
else:
raise NotImplementedError(robot_type)
config_path = ROBOT_CONFIG_PATH_TEMPLATE.format(robot=robot_type)
robot_cfg = init_hydra_config(config_path, overrides)
robot = make_robot_from_cfg(robot_cfg)
return robot
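# Example (sketch): build a fully mocked Koch robot, as the device tests do. The
# "~cameras" override (hydra's delete syntax) drops the cameras from the config.
#
#   robot = make_robot("koch", overrides=["~cameras"], mock=True)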
def make_camera(camera_type, **kwargs) -> Camera:
if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
camera_index = kwargs.pop("camera_index", OPENCV_CAMERA_INDEX)
return OpenCVCamera(camera_index, **kwargs)
elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
camera_index = kwargs.pop("camera_index", INTELREALSENSE_CAMERA_INDEX)
return IntelRealSenseCamera(camera_index, **kwargs)
else:
raise ValueError(f"The camera type '{camera_type}' is not valid.")
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
if motor_type == "dynamixel":
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
port = kwargs.pop("port", DYNAMIXEL_PORT)
motors = kwargs.pop("motors", DYNAMIXEL_MOTORS)
return DynamixelMotorsBus(port, motors, **kwargs)
else:
raise ValueError(f"The motor type '{motor_type}' is not valid.")