All tests passing except test_control_robot.py
This commit is contained in:
parent
a0432f1608
commit
798373e7bf
|
@ -19,7 +19,7 @@ import gymnasium as gym
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
|
||||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
|
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||||
"""Makes a gym vector environment according to the evaluation config.
|
"""Makes a gym vector environment according to the evaluation config.
|
||||||
|
|
||||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||||
|
@ -27,6 +27,9 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||||
if n_envs is not None and n_envs < 1:
|
if n_envs is not None and n_envs < 1:
|
||||||
raise ValueError("`n_envs must be at least 1")
|
raise ValueError("`n_envs must be at least 1")
|
||||||
|
|
||||||
|
if cfg.env.name == "real_world":
|
||||||
|
return
|
||||||
|
|
||||||
package_name = f"gym_{cfg.env.name}"
|
package_name = f"gym_{cfg.env.name}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
import math
|
import math
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -8,9 +9,8 @@ from threading import Thread
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
# Using 1 thread to avoid blocking the main thread.
|
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
|
||||||
# Especially useful during data collection when other threads are used
|
# when other threads are used to save the images.
|
||||||
# to save the images.
|
|
||||||
cv2.setNumThreads(1)
|
cv2.setNumThreads(1)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -89,6 +89,10 @@ class OpenCVCameraConfig:
|
||||||
height: int | None = None
|
height: int | None = None
|
||||||
color: str = "rgb"
|
color: str = "rgb"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.color not in ["rgb", "bgr"]:
|
||||||
|
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
|
||||||
|
|
||||||
|
|
||||||
class OpenCVCamera:
|
class OpenCVCamera:
|
||||||
# TODO(rcadene): improve dosctring
|
# TODO(rcadene): improve dosctring
|
||||||
|
@ -122,12 +126,10 @@ class OpenCVCamera:
|
||||||
if not isinstance(self.camera_index, int):
|
if not isinstance(self.camera_index, int):
|
||||||
raise ValueError(f"Camera index must be provided as an int, but {self.camera_index} was given instead.")
|
raise ValueError(f"Camera index must be provided as an int, but {self.camera_index} was given instead.")
|
||||||
|
|
||||||
if self.color not in ["rgb", "bgr"]:
|
|
||||||
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
|
|
||||||
|
|
||||||
self.camera = None
|
self.camera = None
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
self.thread = None
|
self.thread = None
|
||||||
|
self.stop_event = None
|
||||||
self.color_image = None
|
self.color_image = None
|
||||||
self.logs = {}
|
self.logs = {}
|
||||||
|
|
||||||
|
@ -159,26 +161,26 @@ class OpenCVCamera:
|
||||||
# needs to be re-created.
|
# needs to be re-created.
|
||||||
self.camera = cv2.VideoCapture(self.camera_index)
|
self.camera = cv2.VideoCapture(self.camera_index)
|
||||||
|
|
||||||
if self.fps:
|
if self.fps is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||||
if self.width:
|
if self.width is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
|
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
|
||||||
if self.height:
|
if self.height is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
|
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
|
||||||
|
|
||||||
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
||||||
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
||||||
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||||
|
|
||||||
if self.fps and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}."
|
f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}."
|
||||||
)
|
)
|
||||||
if self.width and self.width != actual_width:
|
if self.width is not None and self.width != actual_width:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}."
|
f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}."
|
||||||
)
|
)
|
||||||
if self.height and self.height != actual_height:
|
if self.height is not None and self.height != actual_height:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}."
|
f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}."
|
||||||
)
|
)
|
||||||
|
@ -216,6 +218,10 @@ class OpenCVCamera:
|
||||||
if requested_color == "rgb":
|
if requested_color == "rgb":
|
||||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
h, w, _ = color_image.shape
|
||||||
|
if h != self.height or w != self.width:
|
||||||
|
raise OSError(f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead.")
|
||||||
|
|
||||||
# log the number of seconds it took to read the image
|
# log the number of seconds it took to read the image
|
||||||
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
@ -225,7 +231,7 @@ class OpenCVCamera:
|
||||||
return color_image
|
return color_image
|
||||||
|
|
||||||
def read_loop(self):
|
def read_loop(self):
|
||||||
while True:
|
while self.stop_event is None or not self.stop_event.is_set():
|
||||||
self.color_image = self.read()
|
self.color_image = self.read()
|
||||||
|
|
||||||
def async_read(self):
|
def async_read(self):
|
||||||
|
@ -233,6 +239,7 @@ class OpenCVCamera:
|
||||||
raise RobotDeviceNotConnectedError(f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first.")
|
raise RobotDeviceNotConnectedError(f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first.")
|
||||||
|
|
||||||
if self.thread is None:
|
if self.thread is None:
|
||||||
|
self.stop_event = threading.Event()
|
||||||
self.thread = Thread(target=self.read_loop, args=())
|
self.thread = Thread(target=self.read_loop, args=())
|
||||||
self.thread.daemon = True
|
self.thread.daemon = True
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
|
@ -242,27 +249,29 @@ class OpenCVCamera:
|
||||||
num_tries += 1
|
num_tries += 1
|
||||||
time.sleep(1/self.fps)
|
time.sleep(1/self.fps)
|
||||||
if num_tries > self.fps:
|
if num_tries > self.fps:
|
||||||
if self.thread.ident is None and not self.thread.is_alive():
|
if self.thread.ident is None or not self.thread.is_alive():
|
||||||
raise Exception("The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called.")
|
raise Exception("The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called.")
|
||||||
|
|
||||||
return self.color_image
|
return self.color_image
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first.")
|
raise RobotDeviceNotConnectedError(f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first.")
|
||||||
|
|
||||||
self.camera.release()
|
|
||||||
self.camera = None
|
|
||||||
|
|
||||||
if self.thread is not None and self.thread.is_alive():
|
if self.thread is not None and self.thread.is_alive():
|
||||||
# wait for the thread to finish
|
# wait for the thread to finish
|
||||||
|
self.stop_event.set()
|
||||||
self.thread.join()
|
self.thread.join()
|
||||||
self.thread = None
|
self.thread = None
|
||||||
|
self.stop_event = None
|
||||||
|
|
||||||
|
self.camera.release()
|
||||||
|
self.camera = None
|
||||||
|
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.is_connected:
|
if getattr(self, "is_connected", False):
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -101,6 +101,7 @@ MODEL_CONTROL_TABLE = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(rcadene): find better namming for these functions
|
||||||
def uint32_to_int32(values: np.ndarray):
|
def uint32_to_int32(values: np.ndarray):
|
||||||
"""
|
"""
|
||||||
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
|
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
|
||||||
|
@ -120,35 +121,18 @@ def int32_to_uint32(values: np.ndarray):
|
||||||
values[i] = values[i] + 4294967296
|
values[i] = values[i] + 4294967296
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
# def motor_position_to_angle(position: np.ndarray) -> np.ndarray:
|
||||||
def motor_position_to_angle(position: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Convert from motor position in [-2048, 2048] to radian in [-pi, pi]
|
|
||||||
"""
|
|
||||||
return (position / 2048) * 3.14
|
|
||||||
|
|
||||||
|
|
||||||
def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Convert from radian in [-pi, pi] to motor position in [-2048, 2048]
|
|
||||||
"""
|
|
||||||
return ((angle / 3.14) * 2048).astype(np.int64)
|
|
||||||
|
|
||||||
|
|
||||||
# def pwm2vel(pwm: np.ndarray) -> np.ndarray:
|
|
||||||
# """
|
# """
|
||||||
# :param pwm: numpy array of pwm/s joint velocities
|
# Convert from motor position in [-2048, 2048] to radian in [-pi, pi]
|
||||||
# :return: numpy array of rad/s joint velocities
|
|
||||||
# """
|
# """
|
||||||
# return pwm * 3.14 / 2048
|
# return (position / 2048) * 3.14
|
||||||
|
|
||||||
|
|
||||||
# def vel2pwm(vel: np.ndarray) -> np.ndarray:
|
# def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
|
||||||
# """
|
# """
|
||||||
# :param vel: numpy array of rad/s joint velocities
|
# Convert from radian in [-pi, pi] to motor position in [-2048, 2048]
|
||||||
# :return: numpy array of pwm/s joint velocities
|
|
||||||
# """
|
# """
|
||||||
# return (vel * 2048 / 3.14).astype(np.int64)
|
# return ((angle / 3.14) * 2048).astype(np.int64)
|
||||||
|
|
||||||
|
|
||||||
def get_group_sync_key(data_name, motor_names):
|
def get_group_sync_key(data_name, motor_names):
|
||||||
|
@ -285,15 +269,18 @@ class DynamixelMotorsBus:
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def read(self, data_name, motor_names: list[str] | None = None):
|
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise ValueError(f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`.")
|
raise RobotDeviceNotConnectedError(f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`.")
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if motor_names is None:
|
if motor_names is None:
|
||||||
motor_names = self.motor_names
|
motor_names = self.motor_names
|
||||||
|
|
||||||
|
if isinstance(motor_names, str):
|
||||||
|
motor_names = [motor_names]
|
||||||
|
|
||||||
motor_ids = []
|
motor_ids = []
|
||||||
models = []
|
models = []
|
||||||
for name in motor_names:
|
for name in motor_names:
|
||||||
|
@ -352,7 +339,7 @@ class DynamixelMotorsBus:
|
||||||
|
|
||||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise ValueError(f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`.")
|
raise RobotDeviceNotConnectedError(f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`.")
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
@ -444,10 +431,17 @@ class DynamixelMotorsBus:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise RobotDeviceNotConnectedError(f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first.")
|
raise RobotDeviceNotConnectedError(f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first.")
|
||||||
|
|
||||||
closePort
|
if self.port_handler is not None:
|
||||||
|
self.port_handler.closePort()
|
||||||
|
self.port_handler = None
|
||||||
|
|
||||||
|
self.packet_handler = None
|
||||||
|
self.group_readers = {}
|
||||||
|
self.group_writers = {}
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.is_connected:
|
if getattr(self, "is_connected", False):
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,15 +14,21 @@ from lerobot.common.robot_devices.motors.dynamixel import (
|
||||||
TorqueMode,
|
TorqueMode,
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||||
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
|
|
||||||
|
URL_HORIZONTAL_POSITION = {
|
||||||
|
"follower": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/follower_horizontal.png",
|
||||||
|
"leader": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/leader_horizontal.png",
|
||||||
|
}
|
||||||
|
URL_90_DEGREE_POSITION = {
|
||||||
|
"follower": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/follower_90_degree.png",
|
||||||
|
"leader": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/leader_90_degree.png",
|
||||||
|
}
|
||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
# Calibration logic
|
# Calibration logic
|
||||||
########################################################################
|
########################################################################
|
||||||
|
|
||||||
# TARGET_HORIZONTAL_POSITION = motor_position_to_angle(np.array([0, -1024, 1024, 0, -1024, 0]))
|
|
||||||
# TARGET_90_DEGREE_POSITION = motor_position_to_angle(np.array([1024, 0, 0, 1024, 0, -1024]))
|
|
||||||
# GRIPPER_OPEN = motor_position_to_angle(np.array([-400]))
|
|
||||||
|
|
||||||
TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
|
TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
|
||||||
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
|
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
|
||||||
GRIPPER_OPEN = np.array([-400])
|
GRIPPER_OPEN = np.array([-400])
|
||||||
|
@ -137,11 +143,16 @@ def reset_arm(arm: MotorsBus):
|
||||||
arm.write("Drive_Mode", DriveMode.NON_INVERTED.value)
|
arm.write("Drive_Mode", DriveMode.NON_INVERTED.value)
|
||||||
|
|
||||||
|
|
||||||
def run_arm_calibration(arm: MotorsBus, name: str):
|
def run_arm_calibration(arm: MotorsBus, name: str, arm_type: str):
|
||||||
|
""" Example of usage:
|
||||||
|
```python
|
||||||
|
run_arm_calibration(arm, "left", "follower")
|
||||||
|
```
|
||||||
|
"""
|
||||||
reset_arm(arm)
|
reset_arm(arm)
|
||||||
|
|
||||||
# TODO(rcadene): document what position 1 mean
|
# TODO(rcadene): document what position 1 mean
|
||||||
print(f"Please move the '{name}' arm to the horizontal position (gripper fully closed)")
|
print(f"Please move the '{name} {arm_type}' arm to the horizontal position (gripper fully closed, see {URL_HORIZONTAL_POSITION[arm_type]})")
|
||||||
input("Press Enter to continue...")
|
input("Press Enter to continue...")
|
||||||
|
|
||||||
horizontal_homing_offset = compute_homing_offset(
|
horizontal_homing_offset = compute_homing_offset(
|
||||||
|
@ -149,7 +160,7 @@ def run_arm_calibration(arm: MotorsBus, name: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(rcadene): document what position 2 mean
|
# TODO(rcadene): document what position 2 mean
|
||||||
print(f"Please move the '{name}' arm to the 90 degree position (gripper fully open)")
|
print(f"Please move the '{name} {arm_type}' arm to the 90 degree position (gripper fully open, see {URL_90_DEGREE_POSITION[arm_type]})")
|
||||||
input("Press Enter to continue...")
|
input("Press Enter to continue...")
|
||||||
|
|
||||||
drive_mode = compute_drive_mode(arm, horizontal_homing_offset)
|
drive_mode = compute_drive_mode(arm, horizontal_homing_offset)
|
||||||
|
@ -184,42 +195,15 @@ class KochRobotConfig:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Define all the components of the robot
|
# Define all components of the robot
|
||||||
leader_arms: dict[str, MotorsBus] = field(
|
leader_arms: dict[str, MotorsBus] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {}
|
||||||
"main": DynamixelMotorsBus(
|
|
||||||
port="/dev/tty.usbmodem575E0031751",
|
|
||||||
motors={
|
|
||||||
# name: (index, model)
|
|
||||||
"shoulder_pan": (1, "xl330-m077"),
|
|
||||||
"shoulder_lift": (2, "xl330-m077"),
|
|
||||||
"elbow_flex": (3, "xl330-m077"),
|
|
||||||
"wrist_flex": (4, "xl330-m077"),
|
|
||||||
"wrist_roll": (5, "xl330-m077"),
|
|
||||||
"gripper": (6, "xl330-m077"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
follower_arms: dict[str, MotorsBus] = field(
|
follower_arms: dict[str, MotorsBus] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {}
|
||||||
"main": DynamixelMotorsBus(
|
|
||||||
port="/dev/tty.usbmodem575E0032081",
|
|
||||||
motors={
|
|
||||||
# name: (index, model)
|
|
||||||
"shoulder_pan": (1, "xl430-w250"),
|
|
||||||
"shoulder_lift": (2, "xl430-w250"),
|
|
||||||
"elbow_flex": (3, "xl330-m288"),
|
|
||||||
"wrist_flex": (4, "xl330-m288"),
|
|
||||||
"wrist_roll": (5, "xl330-m288"),
|
|
||||||
"gripper": (6, "xl330-m288"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||||
|
|
||||||
|
|
||||||
class KochRobot:
|
class KochRobot:
|
||||||
"""Tau Robotics: https://tau-robotics.com
|
"""Tau Robotics: https://tau-robotics.com
|
||||||
|
|
||||||
|
@ -306,6 +290,11 @@ class KochRobot:
|
||||||
# Orders the robot to move
|
# Orders the robot to move
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Example of disconnecting which is not mandatory since we disconnect when the object is deleted:
|
||||||
|
```python
|
||||||
|
robot.disconnect()
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -328,59 +317,68 @@ class KochRobot:
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
raise ValueError(f"KochRobot is already connected.")
|
raise RobotDeviceAlreadyConnectedError(f"KochRobot is already connected. Do not run `robot.connect()` twice.")
|
||||||
|
|
||||||
|
if not self.leader_arms and not self.follower_arms and not self.cameras:
|
||||||
|
raise ValueError("KochRobot doesn't have any device to connect. See example of usage in docstring of the class.")
|
||||||
|
|
||||||
|
# Connect the arms
|
||||||
for name in self.follower_arms:
|
for name in self.follower_arms:
|
||||||
self.follower_arms[name].connect()
|
self.follower_arms[name].connect()
|
||||||
self.leader_arms[name].connect()
|
self.leader_arms[name].connect()
|
||||||
|
|
||||||
|
# Reset the arms and load or run calibration
|
||||||
if self.calibration_path.exists():
|
if self.calibration_path.exists():
|
||||||
# Reset all arms before setting calibration
|
# Reset all arms before setting calibration
|
||||||
for name in self.follower_arms:
|
for name in self.follower_arms:
|
||||||
reset_arm(self.follower_arms[name])
|
reset_arm(self.follower_arms[name])
|
||||||
|
|
||||||
for name in self.leader_arms:
|
for name in self.leader_arms:
|
||||||
reset_arm(self.leader_arms[name])
|
reset_arm(self.leader_arms[name])
|
||||||
|
|
||||||
with open(self.calibration_path, "rb") as f:
|
with open(self.calibration_path, "rb") as f:
|
||||||
calibration = pickle.load(f)
|
calibration = pickle.load(f)
|
||||||
else:
|
else:
|
||||||
|
# Run calibration process which begins by reseting all arms
|
||||||
calibration = self.run_calibration()
|
calibration = self.run_calibration()
|
||||||
|
|
||||||
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
|
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(self.calibration_path, "wb") as f:
|
with open(self.calibration_path, "wb") as f:
|
||||||
pickle.dump(calibration, f)
|
pickle.dump(calibration, f)
|
||||||
|
|
||||||
|
# Set calibration
|
||||||
for name in self.follower_arms:
|
for name in self.follower_arms:
|
||||||
self.follower_arms[name].set_calibration(calibration[f"follower_{name}"])
|
self.follower_arms[name].set_calibration(calibration[f"follower_{name}"])
|
||||||
self.follower_arms[name].write("Torque_Enable", 1)
|
|
||||||
|
|
||||||
for name in self.leader_arms:
|
for name in self.leader_arms:
|
||||||
self.leader_arms[name].set_calibration(calibration[f"leader_{name}"])
|
self.leader_arms[name].set_calibration(calibration[f"leader_{name}"])
|
||||||
# TODO(rcadene): add comments
|
|
||||||
self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN, "gripper")
|
|
||||||
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
|
||||||
|
|
||||||
|
# Enable torque on all motors of the follower arms
|
||||||
|
for name in self.follower_arms:
|
||||||
|
self.follower_arms[name].write("Torque_Enable", 1)
|
||||||
|
|
||||||
|
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
|
||||||
|
# so that we can use it as a trigger to close the gripper of the follower arms.
|
||||||
|
for name in self.leader_arms:
|
||||||
|
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
||||||
|
self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN, "gripper")
|
||||||
|
|
||||||
|
# Connect the cameras
|
||||||
for name in self.cameras:
|
for name in self.cameras:
|
||||||
self.cameras[name].connect()
|
self.cameras[name].connect()
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
def run_calibration(self):
|
def run_calibration(self):
|
||||||
if not self.is_connected:
|
|
||||||
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
|
||||||
|
|
||||||
calibration = {}
|
calibration = {}
|
||||||
|
|
||||||
for name in self.follower_arms:
|
for name in self.follower_arms:
|
||||||
homing_offset, drive_mode = run_arm_calibration(self.follower_arms[name], f"{name} follower")
|
homing_offset, drive_mode = run_arm_calibration(self.follower_arms[name], name, "follower")
|
||||||
|
|
||||||
calibration[f"follower_{name}"] = {}
|
calibration[f"follower_{name}"] = {}
|
||||||
for idx, motor_name in enumerate(self.follower_arms[name].motor_names):
|
for idx, motor_name in enumerate(self.follower_arms[name].motor_names):
|
||||||
calibration[f"follower_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx])
|
calibration[f"follower_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx])
|
||||||
|
|
||||||
for name in self.leader_arms:
|
for name in self.leader_arms:
|
||||||
homing_offset, drive_mode = run_arm_calibration(self.leader_arms[name], f"{name} leader")
|
homing_offset, drive_mode = run_arm_calibration(self.leader_arms[name], name, "leader")
|
||||||
|
|
||||||
calibration[f"leader_{name}"] = {}
|
calibration[f"leader_{name}"] = {}
|
||||||
for idx, motor_name in enumerate(self.leader_arms[name].motor_names):
|
for idx, motor_name in enumerate(self.leader_arms[name].motor_names):
|
||||||
|
@ -392,7 +390,7 @@ class KochRobot:
|
||||||
self, record_data=False
|
self, record_data=False
|
||||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
||||||
|
|
||||||
# Prepare to assign the positions of the leader to the follower
|
# Prepare to assign the positions of the leader to the follower
|
||||||
leader_pos = {}
|
leader_pos = {}
|
||||||
|
@ -455,7 +453,7 @@ class KochRobot:
|
||||||
|
|
||||||
def capture_observation(self):
|
def capture_observation(self):
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
||||||
|
|
||||||
# Read follower position
|
# Read follower position
|
||||||
follower_pos = {}
|
follower_pos = {}
|
||||||
|
@ -481,9 +479,9 @@ class KochRobot:
|
||||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
def send_action(self, action):
|
def send_action(self, action: torch.Tensor):
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
||||||
|
|
||||||
from_idx = 0
|
from_idx = 0
|
||||||
to_idx = 0
|
to_idx = 0
|
||||||
|
@ -496,3 +494,22 @@ class KochRobot:
|
||||||
|
|
||||||
for name in self.follower_arms:
|
for name in self.follower_arms:
|
||||||
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name].astype(np.int32))
|
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name].astype(np.int32))
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
if not self.is_connected:
|
||||||
|
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()` before disconnecting.")
|
||||||
|
|
||||||
|
for name in self.follower_arms:
|
||||||
|
self.follower_arms[name].disconnect()
|
||||||
|
|
||||||
|
for name in self.leader_arms:
|
||||||
|
self.leader_arms[name].disconnect()
|
||||||
|
|
||||||
|
for name in self.cameras:
|
||||||
|
self.cameras[name].disconnect()
|
||||||
|
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if getattr(self, "is_connected", False):
|
||||||
|
self.disconnect()
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||||
|
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||||
|
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||||
|
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||||
|
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||||
|
#
|
||||||
|
# Example of usage for training:
|
||||||
|
# ```bash
|
||||||
|
# python lerobot/scripts/train.py \
|
||||||
|
# policy=act_koch_real \
|
||||||
|
# env=koch_real
|
||||||
|
# ```
|
||||||
|
|
||||||
|
seed: 1000
|
||||||
|
dataset_repo_id: lerobot/koch_pick_place_lego
|
||||||
|
|
||||||
|
override_dataset_stats:
|
||||||
|
observation.images.laptop:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
observation.images.phone:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
|
||||||
|
training:
|
||||||
|
offline_steps: 80000
|
||||||
|
online_steps: 0
|
||||||
|
eval_freq: -1
|
||||||
|
save_freq: 10000
|
||||||
|
log_freq: 100
|
||||||
|
save_checkpoint: true
|
||||||
|
|
||||||
|
batch_size: 8
|
||||||
|
lr: 1e-5
|
||||||
|
lr_backbone: 1e-5
|
||||||
|
weight_decay: 1e-4
|
||||||
|
grad_clip_norm: 10
|
||||||
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
|
delta_timestamps:
|
||||||
|
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||||
|
|
||||||
|
eval:
|
||||||
|
n_episodes: 50
|
||||||
|
batch_size: 50
|
||||||
|
|
||||||
|
# See `configuration_act.py` for more details.
|
||||||
|
policy:
|
||||||
|
name: act
|
||||||
|
|
||||||
|
# Input / output structure.
|
||||||
|
n_obs_steps: 1
|
||||||
|
chunk_size: 100
|
||||||
|
n_action_steps: 100
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
observation.images.laptop: [3, 480, 640]
|
||||||
|
observation.images.phone: [3, 480, 640]
|
||||||
|
observation.state: ["${env.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes:
|
||||||
|
observation.images.laptop: mean_std
|
||||||
|
observation.images.phone: mean_std
|
||||||
|
observation.state: mean_std
|
||||||
|
output_normalization_modes:
|
||||||
|
action: mean_std
|
||||||
|
|
||||||
|
# Architecture.
|
||||||
|
# Vision backbone.
|
||||||
|
vision_backbone: resnet18
|
||||||
|
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||||
|
replace_final_stride_with_dilation: false
|
||||||
|
# Transformer layers.
|
||||||
|
pre_norm: false
|
||||||
|
dim_model: 512
|
||||||
|
n_heads: 8
|
||||||
|
dim_feedforward: 3200
|
||||||
|
feedforward_activation: relu
|
||||||
|
n_encoder_layers: 4
|
||||||
|
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||||
|
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||||
|
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||||
|
n_decoder_layers: 1
|
||||||
|
# VAE.
|
||||||
|
use_vae: true
|
||||||
|
latent_dim: 32
|
||||||
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
|
# Inference.
|
||||||
|
temporal_ensemble_momentum: null
|
||||||
|
|
||||||
|
# Training and loss computation.
|
||||||
|
dropout: 0.1
|
||||||
|
kl_weight: 10.0
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
Example of usage:
|
Examples of usage:
|
||||||
|
|
||||||
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
|
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
|
||||||
```bash
|
```bash
|
||||||
|
@ -49,15 +49,19 @@ python lerobot/scripts/control_robot.py record_dataset \
|
||||||
--run-compute-stats 1
|
--run-compute-stats 1
|
||||||
```
|
```
|
||||||
|
|
||||||
- Train on this dataset (TODO(rcadene)):
|
- Train on this dataset with the ACT policy:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/train.py
|
DATA_DIR=data python lerobot/scripts/train.py \
|
||||||
|
policy=act_koch_real \
|
||||||
|
env=koch_real \
|
||||||
|
dataset_repo_id=$USER/koch_pick_place_lego \
|
||||||
|
hydra.run.dir=outputs/train/act_koch_real
|
||||||
```
|
```
|
||||||
|
|
||||||
- Run the pretrained policy on the robot:
|
- Run the pretrained policy on the robot:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py run_policy \
|
python lerobot/scripts/control_robot.py run_policy \
|
||||||
-p TODO(rcadene)
|
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -117,29 +121,37 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
|
||||||
log_items += [f"ep:{episode_index}"]
|
log_items += [f"ep:{episode_index}"]
|
||||||
if frame_index is not None:
|
if frame_index is not None:
|
||||||
log_items += [f"frame:{frame_index}"]
|
log_items += [f"frame:{frame_index}"]
|
||||||
|
|
||||||
# total step time displayed in milliseconds and its frequency
|
|
||||||
log_items += [f"dt:{dt_s * 1000:5.2f}={1/ dt_s:3.1f}hz"]
|
|
||||||
|
|
||||||
|
def log_dt(shortname, dt_val_s):
|
||||||
|
nonlocal log_items
|
||||||
|
log_items += [f"{shortname}:{dt_val_s * 1000:5.2f}={1/ dt_val_s:3.1f}hz"]
|
||||||
|
|
||||||
|
# total step time displayed in milliseconds and its frequency
|
||||||
|
log_dt("dt", dt_s)
|
||||||
|
|
||||||
for name in robot.leader_arms:
|
for name in robot.leader_arms:
|
||||||
read_dt_s = robot.logs[f'read_leader_{name}_pos_dt_s']
|
key = f'read_leader_{name}_pos_dt_s'
|
||||||
log_items += [
|
if key in robot.logs:
|
||||||
f"dtRlead{name[0]}:{read_dt_s * 1000:5.2f}={1/ read_dt_s:3.1f}hz",
|
log_dt("dtRlead", robot.logs[key])
|
||||||
]
|
|
||||||
for name in robot.follower_arms:
|
for name in robot.follower_arms:
|
||||||
write_dt_s = robot.logs[f'write_follower_{name}_goal_pos_dt_s']
|
key = f'write_follower_{name}_goal_pos_dt_s'
|
||||||
read_dt_s = robot.logs[f'read_follower_{name}_pos_dt_s']
|
if key in robot.logs:
|
||||||
log_items += [
|
log_dt("dtRfoll", robot.logs[key])
|
||||||
f"dtRfoll{name[0]}:{write_dt_s * 1000:5.2f}={1/ write_dt_s:3.1f}hz",
|
|
||||||
f"dtWfoll{name[0]}:{read_dt_s * 1000:5.2f}={1/ read_dt_s:3.1f}hz",
|
key = f'read_follower_{name}_pos_dt_s'
|
||||||
]
|
if key in robot.logs:
|
||||||
|
log_dt("dtWfoll", robot.logs[key])
|
||||||
|
|
||||||
for name in robot.cameras:
|
for name in robot.cameras:
|
||||||
read_dt_s = robot.logs[f"read_camera_{name}_dt_s"]
|
key = f"read_camera_{name}_dt_s"
|
||||||
async_read_dt_s = robot.logs[f"async_read_camera_{name}_dt_s"]
|
if key in robot.logs:
|
||||||
log_items += [
|
log_dt("dtRcam", robot.logs[key])
|
||||||
f"dtRcam{name[0]}:{read_dt_s * 1000:5.2f}={1/read_dt_s:3.1f}hz",
|
|
||||||
f"dtARcam{name[0]}:{async_read_dt_s * 1000:5.2f}={1/async_read_dt_s:3.1f}hz",
|
key = f"async_read_camera_{name}_dt_s"
|
||||||
]
|
if key in robot.logs:
|
||||||
|
log_dt("dtARcam", robot.logs[key])
|
||||||
|
|
||||||
logging.info(" ".join(log_items))
|
logging.info(" ".join(log_items))
|
||||||
|
|
||||||
########################################################################################
|
########################################################################################
|
||||||
|
@ -147,10 +159,12 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
|
||||||
########################################################################################
|
########################################################################################
|
||||||
|
|
||||||
|
|
||||||
def teleoperate(robot: Robot, fps: int | None = None):
|
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
|
||||||
|
# TODO(rcadene): Add option to record logs
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
while True:
|
while True:
|
||||||
now = time.perf_counter()
|
now = time.perf_counter()
|
||||||
robot.teleop_step()
|
robot.teleop_step()
|
||||||
|
@ -162,6 +176,9 @@ def teleoperate(robot: Robot, fps: int | None = None):
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
log_control_info(robot, dt_s)
|
log_control_info(robot, dt_s)
|
||||||
|
|
||||||
|
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
def record_dataset(
|
def record_dataset(
|
||||||
robot: Robot,
|
robot: Robot,
|
||||||
|
@ -174,6 +191,8 @@ def record_dataset(
|
||||||
video=True,
|
video=True,
|
||||||
run_compute_stats=True,
|
run_compute_stats=True,
|
||||||
):
|
):
|
||||||
|
# TODO(rcadene): Add option to record logs
|
||||||
|
|
||||||
if not video:
|
if not video:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -327,8 +346,11 @@ def record_dataset(
|
||||||
|
|
||||||
# TODO(rcadene): push to hub
|
# TODO(rcadene): push to hub
|
||||||
|
|
||||||
|
return lerobot_dataset
|
||||||
|
|
||||||
|
|
||||||
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||||
|
# TODO(rcadene): Add option to record logs
|
||||||
local_dir = Path(root) / repo_id
|
local_dir = Path(root) / repo_id
|
||||||
if not local_dir.exists():
|
if not local_dir.exists():
|
||||||
raise ValueError(local_dir)
|
raise ValueError(local_dir)
|
||||||
|
@ -357,7 +379,8 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
||||||
log_control_info(robot, dt_s)
|
log_control_info(robot, dt_s)
|
||||||
|
|
||||||
|
|
||||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
|
||||||
|
# TODO(rcadene): Add option to record eval dataset and logs
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
|
@ -372,6 +395,7 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
while True:
|
while True:
|
||||||
now = time.perf_counter()
|
now = time.perf_counter()
|
||||||
|
|
||||||
|
@ -391,6 +415,9 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
||||||
dt_s = time.perf_counter() - now
|
dt_s = time.perf_counter() - now
|
||||||
log_control_info(robot, dt_s)
|
log_control_info(robot, dt_s)
|
||||||
|
|
||||||
|
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 416 KiB |
Binary file not shown.
After Width: | Height: | Size: 446 KiB |
Binary file not shown.
After Width: | Height: | Size: 318 KiB |
Binary file not shown.
After Width: | Height: | Size: 420 KiB |
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -7,17 +8,28 @@ from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||||
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError, RobotDeviceAlreadyConnectedError
|
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError, RobotDeviceAlreadyConnectedError
|
||||||
|
|
||||||
|
|
||||||
def test_camera():
|
CAMERA_INDEX = 2
|
||||||
# Test instantiating with missing camera index raises an error
|
# Maximum absolute difference between two consecutive images recored by a camera.
|
||||||
with pytest.raises(ValueError):
|
# This value differs with respect to the camera.
|
||||||
camera = OpenCVCamera()
|
MAX_PIXEL_DIFFERENCE = 25
|
||||||
|
|
||||||
# Test instantiating with a wrong camera index raises an error
|
def compute_max_pixel_difference(first_image, second_image):
|
||||||
with pytest.raises(ValueError):
|
return np.abs(first_image.astype(float) - second_image.astype(float)).max()
|
||||||
camera = OpenCVCamera(-1)
|
|
||||||
|
|
||||||
|
def test_camera():
|
||||||
|
"""Test assumes that `camera.read()` returns the same image when called multiple times in a row.
|
||||||
|
So the environment should not change (you shouldnt be in front of the camera) and the camera should not be moving.
|
||||||
|
|
||||||
|
Warning: The tests worked for a macbookpro camera, but I am getting assertion error (`np.allclose(color_image, async_color_image)`)
|
||||||
|
for my iphone camera and my LG monitor camera.
|
||||||
|
"""
|
||||||
|
# TODO(rcadene): measure fps in nightly?
|
||||||
|
# TODO(rcadene): test logs
|
||||||
|
# TODO(rcadene): add compatibility with other camera APIs
|
||||||
|
|
||||||
# Test instantiating
|
# Test instantiating
|
||||||
camera = OpenCVCamera(0)
|
camera = OpenCVCamera(CAMERA_INDEX)
|
||||||
|
|
||||||
# Test reading, async reading, disconnecting before connecting raises an error
|
# Test reading, async reading, disconnecting before connecting raises an error
|
||||||
with pytest.raises(RobotDeviceNotConnectedError):
|
with pytest.raises(RobotDeviceNotConnectedError):
|
||||||
|
@ -31,7 +43,7 @@ def test_camera():
|
||||||
del camera
|
del camera
|
||||||
|
|
||||||
# Test connecting
|
# Test connecting
|
||||||
camera = OpenCVCamera(0)
|
camera = OpenCVCamera(CAMERA_INDEX)
|
||||||
camera.connect()
|
camera.connect()
|
||||||
assert camera.is_connected
|
assert camera.is_connected
|
||||||
assert camera.fps is not None
|
assert camera.fps is not None
|
||||||
|
@ -50,9 +62,14 @@ def test_camera():
|
||||||
assert c == 3
|
assert c == 3
|
||||||
assert w > h
|
assert w > h
|
||||||
|
|
||||||
# Test reading asynchronously from the camera and image is similar
|
# Test read and async_read outputs similar images
|
||||||
|
# ...warming up as the first frames can be black
|
||||||
|
for _ in range(30):
|
||||||
|
camera.read()
|
||||||
|
color_image = camera.read()
|
||||||
async_color_image = camera.async_read()
|
async_color_image = camera.async_read()
|
||||||
assert np.allclose(color_image, async_color_image)
|
print("max_pixel_difference between read() and async_read()", compute_max_pixel_difference(color_image, async_color_image))
|
||||||
|
assert np.allclose(color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE)
|
||||||
|
|
||||||
# Test disconnecting
|
# Test disconnecting
|
||||||
camera.disconnect()
|
camera.disconnect()
|
||||||
|
@ -60,27 +77,29 @@ def test_camera():
|
||||||
assert camera.thread is None
|
assert camera.thread is None
|
||||||
|
|
||||||
# Test disconnecting with `__del__`
|
# Test disconnecting with `__del__`
|
||||||
camera = OpenCVCamera(0)
|
camera = OpenCVCamera(CAMERA_INDEX)
|
||||||
camera.connect()
|
camera.connect()
|
||||||
del camera
|
del camera
|
||||||
|
|
||||||
# Test acquiring a bgr image
|
# Test acquiring a bgr image
|
||||||
camera = OpenCVCamera(0, color="bgr")
|
camera = OpenCVCamera(CAMERA_INDEX, color="bgr")
|
||||||
camera.connect()
|
camera.connect()
|
||||||
assert camera.color == "bgr"
|
assert camera.color == "bgr"
|
||||||
bgr_color_image = camera.read()
|
bgr_color_image = camera.read()
|
||||||
assert np.allclose(color_image, bgr_color_image[[2,1,0]])
|
assert np.allclose(color_image, bgr_color_image[:, :, [2,1,0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE)
|
||||||
del camera
|
del camera
|
||||||
|
|
||||||
# Test fps can be set
|
# TODO(rcadene): Add a test for a camera that doesnt support fps=60 and raises an OSError
|
||||||
camera = OpenCVCamera(0, fps=60)
|
# TODO(rcadene): Add a test for a camera that supports fps=60
|
||||||
camera.connect()
|
|
||||||
assert camera.fps == 60
|
# Test fps=10 raises an OSError
|
||||||
# TODO(rcadene): measure fps in nightly?
|
camera = OpenCVCamera(CAMERA_INDEX, fps=10)
|
||||||
|
with pytest.raises(OSError):
|
||||||
|
camera.connect()
|
||||||
del camera
|
del camera
|
||||||
|
|
||||||
# Test width and height can be set
|
# Test width and height can be set
|
||||||
camera = OpenCVCamera(0, fps=30, width=1280, height=720)
|
camera = OpenCVCamera(CAMERA_INDEX, fps=30, width=1280, height=720)
|
||||||
camera.connect()
|
camera.connect()
|
||||||
assert camera.fps == 30
|
assert camera.fps == 30
|
||||||
assert camera.width == 1280
|
assert camera.width == 1280
|
||||||
|
@ -92,7 +111,9 @@ def test_camera():
|
||||||
assert c == 3
|
assert c == 3
|
||||||
del camera
|
del camera
|
||||||
|
|
||||||
|
# Test not supported width and height raise an error
|
||||||
|
camera = OpenCVCamera(CAMERA_INDEX, fps=30, width=0, height=0)
|
||||||
|
with pytest.raises(OSError):
|
||||||
|
camera.connect()
|
||||||
|
del camera
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,47 @@
|
||||||
|
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
from lerobot.scripts.control_robot import record_dataset, replay_episode, run_policy, teleoperate
|
||||||
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
|
||||||
|
|
||||||
|
|
||||||
def test_teleoperate():
|
def test_teleoperate():
|
||||||
pass
|
robot = make_robot("koch")
|
||||||
|
teleoperate(robot, teleop_time_s=1)
|
||||||
|
teleoperate(robot, fps=30, teleop_time_s=1)
|
||||||
|
teleoperate(robot, fps=60, teleop_time_s=1)
|
||||||
|
del robot
|
||||||
|
|
||||||
def test_record_dataset():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_replay_episode():
|
def test_record_dataset_and_replay_episode_and_run_policy(tmpdir):
|
||||||
pass
|
robot_name = "koch"
|
||||||
|
env_name = "koch_real"
|
||||||
|
policy_name = "act_real"
|
||||||
|
|
||||||
|
#root = Path(tmpdir)
|
||||||
|
root = Path("tmp/data")
|
||||||
|
repo_id = "lerobot/debug"
|
||||||
|
|
||||||
|
robot = make_robot(robot_name)
|
||||||
|
dataset = record_dataset(robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=2, episode_time_s=2, num_episodes=2)
|
||||||
|
|
||||||
|
replay_episode(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
||||||
|
|
||||||
|
cfg = init_hydra_config(
|
||||||
|
DEFAULT_CONFIG_PATH,
|
||||||
|
overrides=[
|
||||||
|
f"env={env_name}",
|
||||||
|
f"policy={policy_name}",
|
||||||
|
f"device={DEVICE}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||||
|
|
||||||
|
run_policy(robot, policy, cfg, run_time_s=1)
|
||||||
|
|
||||||
|
del robot
|
||||||
|
|
||||||
def test_run_policy():
|
|
||||||
pass
|
|
|
@ -3,11 +3,18 @@ import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||||
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
|
|
||||||
|
|
||||||
def test_motors_bus():
|
def test_motors_bus():
|
||||||
|
# TODO(rcadene): measure fps in nightly?
|
||||||
|
# TODO(rcadene): test logs
|
||||||
|
# TODO(rcadene): test calibration
|
||||||
|
# TODO(rcadene): add compatibility with other motors bus
|
||||||
|
|
||||||
# Test instantiating a common motors structure.
|
# Test instantiating a common motors structure.
|
||||||
# Here the one from Alexander Koch follower arm.
|
# Here the one from Alexander Koch follower arm.
|
||||||
|
port = "/dev/tty.usbmodem575E0032081"
|
||||||
motors = {
|
motors = {
|
||||||
# name: (index, model)
|
# name: (index, model)
|
||||||
"shoulder_pan": (1, "xl430-w250"),
|
"shoulder_pan": (1, "xl430-w250"),
|
||||||
|
@ -17,24 +24,29 @@ def test_motors_bus():
|
||||||
"wrist_roll": (5, "xl330-m288"),
|
"wrist_roll": (5, "xl330-m288"),
|
||||||
"gripper": (6, "xl330-m288"),
|
"gripper": (6, "xl330-m288"),
|
||||||
}
|
}
|
||||||
motors_bus = DynamixelMotorsBus(
|
motors_bus = DynamixelMotorsBus(port, motors)
|
||||||
port="/dev/tty.usbmodem575E0032081",
|
|
||||||
motors=motors,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test reading and writting before connecting raises an error
|
# Test reading and writting before connecting raises an error
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(RobotDeviceNotConnectedError):
|
||||||
motors_bus.read("Torque_Enable")
|
motors_bus.read("Torque_Enable")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(RobotDeviceNotConnectedError):
|
||||||
motors_bus.write("Torque_Enable")
|
motors_bus.write("Torque_Enable", 1)
|
||||||
|
with pytest.raises(RobotDeviceNotConnectedError):
|
||||||
|
motors_bus.disconnect()
|
||||||
|
|
||||||
|
# Test deleting the object without connecting first
|
||||||
|
del motors_bus
|
||||||
|
|
||||||
|
# Test connecting
|
||||||
|
motors_bus = DynamixelMotorsBus(port, motors)
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
# Test connecting twice raises an error
|
# Test connecting twice raises an error
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(RobotDeviceAlreadyConnectedError):
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
# Test reading torque on all motors and its 0 after first connection
|
# Test disabling torque and reading torque on all motors
|
||||||
|
motors_bus.write("Torque_Enable", 0)
|
||||||
values = motors_bus.read("Torque_Enable")
|
values = motors_bus.read("Torque_Enable")
|
||||||
assert isinstance(values, np.ndarray)
|
assert isinstance(values, np.ndarray)
|
||||||
assert len(values) == len(motors)
|
assert len(values) == len(motors)
|
||||||
|
@ -67,7 +79,5 @@ def test_motors_bus():
|
||||||
# Give time for the motors to move to the goal position
|
# Give time for the motors to move to the goal position
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
new_values = motors_bus.read("Present_Position")
|
new_values = motors_bus.read("Present_Position")
|
||||||
assert new_values == values
|
assert (new_values == values).all()
|
||||||
|
|
||||||
# TODO(rcadene): test calibration
|
|
||||||
# TODO(rcadene): test logs?
|
|
||||||
|
|
|
@ -0,0 +1,108 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import pickle
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
|
from lerobot.common.robot_devices.robots.koch import KochRobot
|
||||||
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
|
|
||||||
|
|
||||||
|
def test_robot(tmpdir):
|
||||||
|
# TODO(rcadene): measure fps in nightly?
|
||||||
|
# TODO(rcadene): test logs
|
||||||
|
# TODO(rcadene): add compatibility with other robots
|
||||||
|
|
||||||
|
# Save calibration preset
|
||||||
|
calibration = {'follower_main': {'shoulder_pan': (-2048, False), 'shoulder_lift': (2048, True), 'elbow_flex': (-1024, False), 'wrist_flex': (2048, True), 'wrist_roll': (2048, True), 'gripper': (2048, True)}, 'leader_main': {'shoulder_pan': (-2048, False), 'shoulder_lift': (1024, True), 'elbow_flex': (2048, True), 'wrist_flex': (-2048, False), 'wrist_roll': (2048, True), 'gripper': (2048, True)}}
|
||||||
|
tmpdir = Path(tmpdir)
|
||||||
|
calibration_path = tmpdir / "calibration.pkl"
|
||||||
|
calibration_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(calibration_path, "wb") as f:
|
||||||
|
pickle.dump(calibration, f)
|
||||||
|
|
||||||
|
# Test connecting without devices raises an error
|
||||||
|
robot = KochRobot()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
robot.connect()
|
||||||
|
del robot
|
||||||
|
|
||||||
|
# Test using robot before connecting raises an error
|
||||||
|
robot = KochRobot()
|
||||||
|
with pytest.raises(RobotDeviceNotConnectedError):
|
||||||
|
robot.teleop_step()
|
||||||
|
with pytest.raises(RobotDeviceNotConnectedError):
|
||||||
|
robot.teleop_step(record_data=True)
|
||||||
|
with pytest.raises(RobotDeviceNotConnectedError):
|
||||||
|
robot.capture_observation()
|
||||||
|
with pytest.raises(RobotDeviceNotConnectedError):
|
||||||
|
robot.send_action(None)
|
||||||
|
with pytest.raises(RobotDeviceNotConnectedError):
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
# Test deleting the object without connecting first
|
||||||
|
del robot
|
||||||
|
|
||||||
|
# Test connecting
|
||||||
|
robot = make_robot("koch")
|
||||||
|
# TODO(rcadene): proper monkey patch
|
||||||
|
robot.calibration_path = calibration_path
|
||||||
|
robot.connect() # run the manual calibration precedure
|
||||||
|
assert robot.is_connected
|
||||||
|
|
||||||
|
# Test connecting twice raises an error
|
||||||
|
with pytest.raises(RobotDeviceAlreadyConnectedError):
|
||||||
|
robot.connect()
|
||||||
|
|
||||||
|
# Test disconnecting with `__del__`
|
||||||
|
del robot
|
||||||
|
|
||||||
|
# Test teleop can run
|
||||||
|
robot = make_robot("koch")
|
||||||
|
robot.calibration_path = calibration_path
|
||||||
|
robot.connect()
|
||||||
|
robot.teleop_step()
|
||||||
|
|
||||||
|
# Test data recorded during teleop are well formated
|
||||||
|
observation, action = robot.teleop_step(record_data=True)
|
||||||
|
# State
|
||||||
|
assert "observation.state" in observation
|
||||||
|
assert isinstance(observation["observation.state"], torch.Tensor)
|
||||||
|
assert observation["observation.state"].ndim == 1
|
||||||
|
dim_state = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
|
||||||
|
assert observation["observation.state"].shape[0] == dim_state
|
||||||
|
# Cameras
|
||||||
|
for name in robot.cameras:
|
||||||
|
assert f"observation.images.{name}" in observation
|
||||||
|
assert isinstance(observation[f"observation.images.{name}"], torch.Tensor)
|
||||||
|
assert observation[f"observation.images.{name}"].ndim == 3
|
||||||
|
# Action
|
||||||
|
assert "action" in action
|
||||||
|
assert isinstance(action["action"], torch.Tensor)
|
||||||
|
assert action["action"].ndim == 1
|
||||||
|
dim_action = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
|
||||||
|
assert action["action"].shape[0] == dim_action
|
||||||
|
# TODO(rcadene): test if observation and action data are returned as expected
|
||||||
|
|
||||||
|
# Test capture_observation can run and observation returned are the same (since the arm didnt move)
|
||||||
|
captured_observation = robot.capture_observation()
|
||||||
|
assert set(captured_observation.keys()) == set(observation.keys())
|
||||||
|
for name in captured_observation:
|
||||||
|
if "image" in name:
|
||||||
|
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
|
||||||
|
continue
|
||||||
|
assert torch.allclose(captured_observation[name], observation[name], atol=1)
|
||||||
|
|
||||||
|
# Test send_action can run
|
||||||
|
robot.send_action(action["action"])
|
||||||
|
|
||||||
|
# Test disconnecting
|
||||||
|
robot.disconnect()
|
||||||
|
assert not robot.is_connected
|
||||||
|
for name in robot.follower_arms:
|
||||||
|
assert not robot.follower_arms[name].is_connected
|
||||||
|
for name in robot.leader_arms:
|
||||||
|
assert not robot.leader_arms[name].is_connected
|
||||||
|
for name in robot.cameras:
|
||||||
|
assert not robot.cameras[name].is_connected
|
||||||
|
del robot
|
||||||
|
|
Loading…
Reference in New Issue