Rewrite MotorsBus

This commit is contained in:
Simon Alibert 2025-03-19 18:44:05 +01:00
parent c85a9253e7
commit 9358d334c7
1 changed files with 289 additions and 82 deletions

View File

@ -20,13 +20,17 @@
# ruff: noqa: N802 # ruff: noqa: N802
import abc import abc
import json
import time import time
import traceback
from enum import Enum from enum import Enum
from functools import cached_property
from pathlib import Path
from pprint import pformat
from typing import Protocol from typing import Protocol
import numpy as np import serial
import tqdm import tqdm
from deepdiff import DeepDiff
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc from lerobot.common.utils.utils import capture_timestamp_utc
@ -34,8 +38,8 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_ID_RANGE = 252 MAX_ID_RANGE = 252
def get_group_sync_key(data_name: str, motor_names: list[str]) -> str: def get_group_sync_key(data_name: str, motor_ids: list[int]) -> str:
group_key = f"{data_name}_" + "_".join(motor_names) group_key = f"{data_name}_" + "_".join([str(idx) for idx in motor_ids])
return group_key return group_key
@ -98,7 +102,7 @@ class PortHandler(Protocol):
self.tx_time_per_byte: float self.tx_time_per_byte: float
self.is_using: bool self.is_using: bool
self.port_name: str self.port_name: str
self.ser: object self.ser: serial.Serial
def openPort(self): ... def openPort(self): ...
def closePort(self): ... def closePort(self): ...
@ -153,6 +157,46 @@ class PacketHandler(Protocol):
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ... def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
class GroupSyncRead(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.last_result: bool
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def makeParam(self): ...
def addParam(self, id): ...
def removeParam(self, id): ...
def clearParam(self): ...
def txPacket(self): ...
def rxPacket(self): ...
def txRxPacket(self): ...
def isAvailable(self, id, address, data_length): ...
def getData(self, id, address, data_length): ...
class GroupSyncWrite(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def makeParam(self): ...
def addParam(self, id, data): ...
def removeParam(self, id): ...
def changeParam(self, id, data): ...
def clearParam(self): ...
def txPacket(self): ...
class MotorsBus(abc.ABC): class MotorsBus(abc.ABC):
"""The main LeRobot class for implementing motors buses. """The main LeRobot class for implementing motors buses.
@ -163,7 +207,7 @@ class MotorsBus(abc.ABC):
Note: This class may evolve in the future should we add support for other manufacturers SDKs. Note: This class may evolve in the future should we add support for other manufacturers SDKs.
A MotorsBus allows to efficiently read and write to the attached motors. A MotorsBus allows to efficiently read and write to the attached motors.
It represents a several motors daisy-chained together and connected through a serial port. It represents several motors daisy-chained together and connected through a serial port.
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script: To find the port, you can run our utility script:
@ -198,6 +242,8 @@ class MotorsBus(abc.ABC):
model_ctrl_table: dict[str, dict] model_ctrl_table: dict[str, dict]
model_resolution_table: dict[str, int] model_resolution_table: dict[str, int]
model_baudrate_table: dict[str, dict] model_baudrate_table: dict[str, dict]
calibration_required: list[str]
default_timeout: int
def __init__( def __init__(
self, self,
@ -206,73 +252,121 @@ class MotorsBus(abc.ABC):
): ):
self.port = port self.port = port
self.motors = motors self.motors = motors
self.port_handler: PortHandler | None = None self._validate_motors()
self.packet_handler: PacketHandler | None = None
self.group_readers = {} self.port_handler: PortHandler
self.group_writers = {} self.packet_handler: PacketHandler
self.logs = {} self.reader: GroupSyncRead
self.writer: GroupSyncWrite
self.logs = {} # TODO(aliberts): use subclass logger
self.calibration = None self.calibration = None
self.is_connected: bool = False
self._id_to_model = dict(self.motors.values())
self._id_to_name = {idx: name for name, (idx, _) in self.motors.items()}
def __len__(self): def __len__(self):
return len(self.motors) return len(self.motors)
@property def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Port: '{self.port}',\n"
f" Motors: \n{pformat(self.motors, indent=8)},\n"
")',\n"
)
@cached_property
def _has_different_ctrl_tables(self) -> bool:
if len(self.motor_models) < 2:
return False
first_table = self.model_ctrl_table[self.motor_models[0]]
return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.motor_models[1:])
def idx_to_model(self, idx: int) -> str:
return self._id_to_model[idx]
def idx_to_name(self, idx: int) -> str:
return self._id_to_name[idx]
@cached_property
def motor_names(self) -> list[str]: def motor_names(self) -> list[str]:
return list(self.motors) return list(self.motors)
@property @cached_property
def motor_models(self) -> list[str]: def motor_models(self) -> list[str]:
return [model for _, model in self.motors.values()] return [model for _, model in self.motors.values()]
@property @cached_property
def motor_indices(self) -> list[int]: def motor_ids(self) -> list[int]:
return [idx for idx, _ in self.motors.values()] return [idx for idx, _ in self.motors.values()]
def connect(self): def _validate_motors(self) -> None:
# TODO(aliberts): improve error messages for this (display problematics values)
if len(self.motor_ids) != len(set(self.motor_ids)):
raise ValueError("Some motors have the same id.")
if len(self.motor_names) != len(set(self.motor_names)):
raise ValueError("Some motors have the same name.")
if any(model not in self.model_resolution_table for model in self.motor_models):
raise ValueError("Some motors models are not available.")
@property
def is_connected(self) -> bool:
return self.port_handler.is_open
def connect(self) -> None:
if self.is_connected: if self.is_connected:
raise DeviceAlreadyConnectedError( raise DeviceAlreadyConnectedError(
f"{self.__name__}({self.port}) is already connected. Do not call `{self.__name__}.connect()` twice." f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
) )
self._set_handlers()
try: try:
if not self.port_handler.openPort(): if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.") raise OSError(f"Failed to open port '{self.port}'.")
except Exception: except (FileNotFoundError, OSError, serial.SerialException) as e:
traceback.print_exc()
print( print(
"\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n" f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
) )
raise raise e
self._set_timeout() self.set_timeout()
# Allow to read and write def set_timeout(self, timeout_ms: int | None = None):
self.is_connected = True timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout
self.port_handler.setPacketTimeoutMillis(timeout_ms)
@abc.abstractmethod @property
def _set_handlers(self): def are_motors_configured(self) -> bool:
pass
@abc.abstractmethod
def _set_timeout(self, timeout: int):
pass
def are_motors_configured(self):
""" """
Only check the motor indices and not baudrate, since if the motor baudrates are incorrect, a Only check the motor indices and not baudrate, since if the motor baudrates are incorrect, a
ConnectionError will be raised anyway. ConnectionError will be raised anyway.
""" """
try: try:
return (self.motor_indices == self.read("ID")).all() # TODO(aliberts): use ping instead
return (self.motor_ids == self.read("ID")).all()
except ConnectionError as e: except ConnectionError as e:
print(e) print(e)
return False return False
def ping(self, motor: str | int, num_retry: int | None = None) -> int:
idx = self.get_safe_id(motor)
for _ in range(num_retry):
model_number, comm, _ = self.packet_handler.ping(self.port, idx)
if self._is_comm_success(comm):
return model_number
# TODO(aliberts): Should we?
return comm
@abc.abstractmethod
def broadcast_ping(self, num_retry: int | None = None):
...
# TODO(aliberts): this will replace 'find_motor_indices'
def find_motor_indices(self, possible_ids: list[str] = None, num_retry: int = 2): def find_motor_indices(self, possible_ids: list[str] = None, num_retry: int = 2):
if possible_ids is None: if possible_ids is None:
possible_ids = range(MAX_ID_RANGE) possible_ids = range(MAX_ID_RANGE)
@ -280,7 +374,7 @@ class MotorsBus(abc.ABC):
indices = [] indices = []
for idx in tqdm.tqdm(possible_ids): for idx in tqdm.tqdm(possible_ids):
try: try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] present_idx = self.read("ID", idx, num_retry=num_retry)[0]
except ConnectionError: except ConnectionError:
continue continue
@ -294,7 +388,7 @@ class MotorsBus(abc.ABC):
return indices return indices
def set_baudrate(self, baudrate): def set_baudrate(self, baudrate) -> None:
present_bus_baudrate = self.port_handler.getBaudRate() present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate: if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
@ -303,94 +397,207 @@ class MotorsBus(abc.ABC):
if self.port_handler.getBaudRate() != baudrate: if self.port_handler.getBaudRate() != baudrate:
raise OSError("Failed to write bus baud rate.") raise OSError("Failed to write bus baud rate.")
def set_calibration(self, calibration_dict: dict[str, list]): def set_calibration(self, calibration_fpath: Path) -> None:
self.calibration = calibration_dict with open(calibration_fpath) as f:
calibration = json.load(f)
self.calibration = {int(idx): val for idx, val in calibration.items()}
@abc.abstractmethod @abc.abstractmethod
def apply_calibration(self): def calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def revert_calibration(self): def uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]:
pass pass
def read(self, data_name, motor_names: str | list[str] | None = None): @abc.abstractmethod
def _is_comm_success(self, comm: int) -> bool:
pass
@staticmethod
@abc.abstractmethod
def split_int_bytes(value: int, n_bytes: int) -> list[int]:
"""
Splits an unsigned integer into a list of bytes in little-endian order.
This function extracts the individual bytes of an integer based on the
specified number of bytes (`n_bytes`). The output is a list of integers,
each representing a byte (0-255).
**Byte order:** The function returns bytes in **little-endian format**,
meaning the least significant byte (LSB) comes first.
Args:
value (int): The unsigned integer to be converted into a byte list. Must be within
the valid range for the specified `n_bytes`.
n_bytes (int): The number of bytes to use for conversion. Supported values:
- 1 (for values 0 to 255)
- 2 (for values 0 to 65,535)
- 4 (for values 0 to 4,294,967,295)
Raises:
ValueError: If `value` is negative or exceeds the maximum allowed for `n_bytes`.
NotImplementedError: If `n_bytes` is not 1, 2, or 4.
Returns:
list[int]: A list of integers, each representing a byte in **little-endian order**.
Examples:
>>> split_int_bytes(0x12, 1)
[18]
>>> split_int_bytes(0x1234, 2)
[52, 18] # 0x1234 → 0x34 0x12 (little-endian)
>>> split_int_bytes(0x12345678, 4)
[120, 86, 52, 18] # 0x12345678 → 0x78 0x56 0x34 0x12
"""
pass
def get_safe_id(self, motor: str | int) -> int:
if isinstance(motor, str):
return self.motors[motor][0]
elif isinstance(motor, int):
return motor
else:
raise ValueError(f"{motor} should be int or str.")
def read(
self, data_name: str, motors: str | int | list[str | int] | None = None, num_retry: int = 1
) -> dict[str, float]:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`." f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
) )
start_time = time.perf_counter() start_time = time.perf_counter()
if motor_names is None:
motor_names = self.motor_names
if isinstance(motor_names, str): if motors is None:
motor_names = [motor_names] motors = self.motor_ids
values = self._read(data_name, motor_names) if isinstance(motors, (str, int)):
motors = [motors]
motor_ids = [self.get_safe_id(motor) for motor in motors]
if self._has_different_ctrl_tables:
models = [self.idx_to_model(idx) for idx in motor_ids]
assert_same_address(self.model_ctrl_table, models, data_name)
model = self.idx_to_model(next(iter(motor_ids)))
addr, n_bytes = self.model_ctrl_table[model][data_name]
comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to read {data_name} on port {self.port} for ids {motor_ids}:"
f"{self.packet_handler.getTxRxResult(comm)}"
)
if data_name in self.calibration_required and self.calibration is not None:
ids_values = self.calibrate_values(ids_values)
# TODO(aliberts): return keys in the same format we got them?
ids_values = {self.idx_to_name(idx): val for idx, val in ids_values.items()}
# log the number of seconds it took to read the data from the motors # log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
self.logs[delta_ts_name] = time.perf_counter() - start_time self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received # log the utc time at which the data was received
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names) ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids)
self.logs[ts_utc_name] = capture_timestamp_utc() self.logs[ts_utc_name] = capture_timestamp_utc()
return values return ids_values
@abc.abstractmethod def _read(
def _read(self, data_name: str, motor_names: list[str]): self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 1
pass ) -> tuple[int, dict[int, int]]:
self.reader.clearParam()
self.reader.start_address = address
self.reader.data_length = n_bytes
def write( # FIXME(aliberts, pkooij): We should probably not have to do this.
self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None # Let's try to see if we can do with better comm status handling instead.
) -> None: # self.port_handler.ser.reset_output_buffer()
# self.port_handler.ser.reset_input_buffer()
for idx in motor_ids:
self.reader.addParam(idx)
for _ in range(num_retry):
comm = self.reader.txRxPacket()
if self._is_comm_success(comm):
break
values = {idx: self.reader.getData(idx, address, n_bytes) for idx in motor_ids}
return comm, values
# TODO(aliberts, pkooij): Implementing something like this could get much faster read times.
# Note: this could be at the cost of increase latency between the moment the data is produced by the
# motors and the moment it is used by a policy
# def _async_read(self, motor_ids: list[str], address: int, n_bytes: int):
# self.reader.rxPacket()
# self.reader.txPacket()
# for idx in motor_ids:
# value = self.reader.getData(idx, address, n_bytes)
def write(self, data_name: str, values_dict: dict[str | int, int], num_retry: int = 1) -> None:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`." f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
) )
start_time = time.perf_counter() start_time = time.perf_counter()
if motor_names is None: ids_values = {self.get_safe_id(motor): val for motor, val in values_dict.items()}
motor_names = self.motor_names
if isinstance(motor_names, str): if self._has_different_ctrl_tables:
motor_names = [motor_names] models = [self.idx_to_model(idx) for idx in ids_values]
assert_same_address(self.model_ctrl_table, models, data_name)
if isinstance(values, (int, float, np.integer)): if data_name in self.calibration_required and self.calibration is not None:
values = [int(values)] * len(motor_names) ids_values = self.uncalibrate_values(ids_values)
self._write(data_name, values, motor_names) model = self.idx_to_model(next(iter(ids_values)))
addr, n_bytes = self.model_ctrl_table[model][data_name]
comm = self._write(ids_values, addr, n_bytes, num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to write {data_name} on port {self.port} for ids {list(ids_values)}:"
f"{self.packet_handler.getTxRxResult(comm)}"
)
# log the number of seconds it took to write the data to the motors # log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, list(ids_values))
self.logs[delta_ts_name] = time.perf_counter() - start_time self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?
# log the utc time when the write has been completed # log the utc time when the write has been completed
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names) ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values))
self.logs[ts_utc_name] = capture_timestamp_utc() self.logs[ts_utc_name] = capture_timestamp_utc()
@abc.abstractmethod def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 1) -> int:
def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None: self.writer.clearParam()
pass self.writer.start_address = address
self.writer.data_length = n_bytes
for idx, value in ids_values.items():
data = self.split_int_bytes(value, n_bytes)
self.writer.addParam(idx, data)
for _ in range(num_retry):
comm = self.writer.txPacket()
if self._is_comm_success(comm):
break
return comm
def disconnect(self) -> None: def disconnect(self) -> None:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first." f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
) )
if self.port_handler is not None: self.port_handler.closePort()
self.port_handler.closePort()
self.port_handler = None
self.packet_handler = None
self.group_readers = {}
self.group_writers = {}
self.is_connected = False
def __del__(self): def __del__(self):
if self.is_connected: if self.is_connected: