diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 26ae3836..24683730 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -20,13 +20,17 @@ # ruff: noqa: N802 import abc +import json import time -import traceback from enum import Enum +from functools import cached_property +from pathlib import Path +from pprint import pformat from typing import Protocol -import numpy as np +import serial import tqdm +from deepdiff import DeepDiff from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.common.utils.utils import capture_timestamp_utc @@ -34,8 +38,8 @@ from lerobot.common.utils.utils import capture_timestamp_utc MAX_ID_RANGE = 252 -def get_group_sync_key(data_name: str, motor_names: list[str]) -> str: - group_key = f"{data_name}_" + "_".join(motor_names) +def get_group_sync_key(data_name: str, motor_ids: list[int]) -> str: + group_key = f"{data_name}_" + "_".join([str(idx) for idx in motor_ids]) return group_key @@ -98,7 +102,7 @@ class PortHandler(Protocol): self.tx_time_per_byte: float self.is_using: bool self.port_name: str - self.ser: object + self.ser: serial.Serial def openPort(self): ... def closePort(self): ... @@ -153,6 +157,46 @@ class PacketHandler(Protocol): def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ... +class GroupSyncRead(Protocol): + def __init__(self, port, ph, start_address, data_length): + self.port: str + self.ph: PortHandler + self.start_address: int + self.data_length: int + self.last_result: bool + self.is_param_changed: bool + self.param: list + self.data_dict: dict + + def makeParam(self): ... + def addParam(self, id): ... + def removeParam(self, id): ... + def clearParam(self): ... + def txPacket(self): ... + def rxPacket(self): ... + def txRxPacket(self): ... + def isAvailable(self, id, address, data_length): ... + def getData(self, id, address, data_length): ... + + +class GroupSyncWrite(Protocol): + def __init__(self, port, ph, start_address, data_length): + self.port: str + self.ph: PortHandler + self.start_address: int + self.data_length: int + self.is_param_changed: bool + self.param: list + self.data_dict: dict + + def makeParam(self): ... + def addParam(self, id, data): ... + def removeParam(self, id): ... + def changeParam(self, id, data): ... + def clearParam(self): ... + def txPacket(self): ... + + class MotorsBus(abc.ABC): """The main LeRobot class for implementing motors buses. @@ -163,7 +207,7 @@ class MotorsBus(abc.ABC): Note: This class may evolve in the future should we add support for other manufacturers SDKs. A MotorsBus allows to efficiently read and write to the attached motors. - It represents a several motors daisy-chained together and connected through a serial port. + It represents several motors daisy-chained together and connected through a serial port. A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). To find the port, you can run our utility script: @@ -198,6 +242,8 @@ class MotorsBus(abc.ABC): model_ctrl_table: dict[str, dict] model_resolution_table: dict[str, int] model_baudrate_table: dict[str, dict] + calibration_required: list[str] + default_timeout: int def __init__( self, @@ -206,73 +252,121 @@ class MotorsBus(abc.ABC): ): self.port = port self.motors = motors - self.port_handler: PortHandler | None = None - self.packet_handler: PacketHandler | None = None + self._validate_motors() - self.group_readers = {} - self.group_writers = {} - self.logs = {} + self.port_handler: PortHandler + self.packet_handler: PacketHandler + self.reader: GroupSyncRead + self.writer: GroupSyncWrite + self.logs = {} # TODO(aliberts): use subclass logger self.calibration = None - self.is_connected: bool = False + + self._id_to_model = dict(self.motors.values()) + self._id_to_name = {idx: name for name, (idx, _) in self.motors.items()} def __len__(self): return len(self.motors) - @property + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Port: '{self.port}',\n" + f" Motors: \n{pformat(self.motors, indent=8)},\n" + ")',\n" + ) + + @cached_property + def _has_different_ctrl_tables(self) -> bool: + if len(self.motor_models) < 2: + return False + + first_table = self.model_ctrl_table[self.motor_models[0]] + return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.motor_models[1:]) + + def idx_to_model(self, idx: int) -> str: + return self._id_to_model[idx] + + def idx_to_name(self, idx: int) -> str: + return self._id_to_name[idx] + + @cached_property def motor_names(self) -> list[str]: return list(self.motors) - @property + @cached_property def motor_models(self) -> list[str]: return [model for _, model in self.motors.values()] - @property - def motor_indices(self) -> list[int]: + @cached_property + def motor_ids(self) -> list[int]: return [idx for idx, _ in self.motors.values()] - def connect(self): + def _validate_motors(self) -> None: + # TODO(aliberts): improve error messages for this (display problematics values) + if len(self.motor_ids) != len(set(self.motor_ids)): + raise ValueError("Some motors have the same id.") + + if len(self.motor_names) != len(set(self.motor_names)): + raise ValueError("Some motors have the same name.") + + if any(model not in self.model_resolution_table for model in self.motor_models): + raise ValueError("Some motors models are not available.") + + @property + def is_connected(self) -> bool: + return self.port_handler.is_open + + def connect(self) -> None: if self.is_connected: raise DeviceAlreadyConnectedError( - f"{self.__name__}({self.port}) is already connected. Do not call `{self.__name__}.connect()` twice." + f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice." ) - self._set_handlers() - try: if not self.port_handler.openPort(): raise OSError(f"Failed to open port '{self.port}'.") - except Exception: - traceback.print_exc() + except (FileNotFoundError, OSError, serial.SerialException) as e: print( - "\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n" + f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port." + "\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n" ) - raise + raise e - self._set_timeout() + self.set_timeout() - # Allow to read and write - self.is_connected = True + def set_timeout(self, timeout_ms: int | None = None): + timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout + self.port_handler.setPacketTimeoutMillis(timeout_ms) - @abc.abstractmethod - def _set_handlers(self): - pass - - @abc.abstractmethod - def _set_timeout(self, timeout: int): - pass - - def are_motors_configured(self): + @property + def are_motors_configured(self) -> bool: """ Only check the motor indices and not baudrate, since if the motor baudrates are incorrect, a ConnectionError will be raised anyway. """ try: - return (self.motor_indices == self.read("ID")).all() + # TODO(aliberts): use ping instead + return (self.motor_ids == self.read("ID")).all() except ConnectionError as e: print(e) return False + def ping(self, motor: str | int, num_retry: int | None = None) -> int: + idx = self.get_safe_id(motor) + for _ in range(num_retry): + model_number, comm, _ = self.packet_handler.ping(self.port, idx) + if self._is_comm_success(comm): + return model_number + + # TODO(aliberts): Should we? + return comm + + @abc.abstractmethod + def broadcast_ping(self, num_retry: int | None = None): + ... + # TODO(aliberts): this will replace 'find_motor_indices' + def find_motor_indices(self, possible_ids: list[str] = None, num_retry: int = 2): if possible_ids is None: possible_ids = range(MAX_ID_RANGE) @@ -280,7 +374,7 @@ class MotorsBus(abc.ABC): indices = [] for idx in tqdm.tqdm(possible_ids): try: - present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] + present_idx = self.read("ID", idx, num_retry=num_retry)[0] except ConnectionError: continue @@ -294,7 +388,7 @@ class MotorsBus(abc.ABC): return indices - def set_baudrate(self, baudrate): + def set_baudrate(self, baudrate) -> None: present_bus_baudrate = self.port_handler.getBaudRate() if present_bus_baudrate != baudrate: print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") @@ -303,94 +397,207 @@ class MotorsBus(abc.ABC): if self.port_handler.getBaudRate() != baudrate: raise OSError("Failed to write bus baud rate.") - def set_calibration(self, calibration_dict: dict[str, list]): - self.calibration = calibration_dict + def set_calibration(self, calibration_fpath: Path) -> None: + with open(calibration_fpath) as f: + calibration = json.load(f) + + self.calibration = {int(idx): val for idx, val in calibration.items()} @abc.abstractmethod - def apply_calibration(self): + def calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]: pass @abc.abstractmethod - def revert_calibration(self): + def uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]: pass - def read(self, data_name, motor_names: str | list[str] | None = None): + @abc.abstractmethod + def _is_comm_success(self, comm: int) -> bool: + pass + + @staticmethod + @abc.abstractmethod + def split_int_bytes(value: int, n_bytes: int) -> list[int]: + """ + Splits an unsigned integer into a list of bytes in little-endian order. + + This function extracts the individual bytes of an integer based on the + specified number of bytes (`n_bytes`). The output is a list of integers, + each representing a byte (0-255). + + **Byte order:** The function returns bytes in **little-endian format**, + meaning the least significant byte (LSB) comes first. + + Args: + value (int): The unsigned integer to be converted into a byte list. Must be within + the valid range for the specified `n_bytes`. + n_bytes (int): The number of bytes to use for conversion. Supported values: + - 1 (for values 0 to 255) + - 2 (for values 0 to 65,535) + - 4 (for values 0 to 4,294,967,295) + + Raises: + ValueError: If `value` is negative or exceeds the maximum allowed for `n_bytes`. + NotImplementedError: If `n_bytes` is not 1, 2, or 4. + + Returns: + list[int]: A list of integers, each representing a byte in **little-endian order**. + + Examples: + >>> split_int_bytes(0x12, 1) + [18] + >>> split_int_bytes(0x1234, 2) + [52, 18] # 0x1234 → 0x34 0x12 (little-endian) + >>> split_int_bytes(0x12345678, 4) + [120, 86, 52, 18] # 0x12345678 → 0x78 0x56 0x34 0x12 + """ + pass + + def get_safe_id(self, motor: str | int) -> int: + if isinstance(motor, str): + return self.motors[motor][0] + elif isinstance(motor, int): + return motor + else: + raise ValueError(f"{motor} should be int or str.") + + def read( + self, data_name: str, motors: str | int | list[str | int] | None = None, num_retry: int = 1 + ) -> dict[str, float]: if not self.is_connected: raise DeviceNotConnectedError( - f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`." + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." ) start_time = time.perf_counter() - if motor_names is None: - motor_names = self.motor_names - if isinstance(motor_names, str): - motor_names = [motor_names] + if motors is None: + motors = self.motor_ids - values = self._read(data_name, motor_names) + if isinstance(motors, (str, int)): + motors = [motors] + + motor_ids = [self.get_safe_id(motor) for motor in motors] + if self._has_different_ctrl_tables: + models = [self.idx_to_model(idx) for idx in motor_ids] + assert_same_address(self.model_ctrl_table, models, data_name) + + model = self.idx_to_model(next(iter(motor_ids))) + addr, n_bytes = self.model_ctrl_table[model][data_name] + + comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry) + if not self._is_comm_success(comm): + raise ConnectionError( + f"Failed to read {data_name} on port {self.port} for ids {motor_ids}:" + f"{self.packet_handler.getTxRxResult(comm)}" + ) + + if data_name in self.calibration_required and self.calibration is not None: + ids_values = self.calibrate_values(ids_values) + + # TODO(aliberts): return keys in the same format we got them? + ids_values = {self.idx_to_name(idx): val for idx, val in ids_values.items()} # log the number of seconds it took to read the data from the motors - delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) + delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids) self.logs[delta_ts_name] = time.perf_counter() - start_time # log the utc time at which the data was received - ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names) + ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids) self.logs[ts_utc_name] = capture_timestamp_utc() - return values + return ids_values - @abc.abstractmethod - def _read(self, data_name: str, motor_names: list[str]): - pass + def _read( + self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 1 + ) -> tuple[int, dict[int, int]]: + self.reader.clearParam() + self.reader.start_address = address + self.reader.data_length = n_bytes - def write( - self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None - ) -> None: + # FIXME(aliberts, pkooij): We should probably not have to do this. + # Let's try to see if we can do with better comm status handling instead. + # self.port_handler.ser.reset_output_buffer() + # self.port_handler.ser.reset_input_buffer() + + for idx in motor_ids: + self.reader.addParam(idx) + + for _ in range(num_retry): + comm = self.reader.txRxPacket() + if self._is_comm_success(comm): + break + + values = {idx: self.reader.getData(idx, address, n_bytes) for idx in motor_ids} + return comm, values + + # TODO(aliberts, pkooij): Implementing something like this could get much faster read times. + # Note: this could be at the cost of increase latency between the moment the data is produced by the + # motors and the moment it is used by a policy + # def _async_read(self, motor_ids: list[str], address: int, n_bytes: int): + # self.reader.rxPacket() + # self.reader.txPacket() + # for idx in motor_ids: + # value = self.reader.getData(idx, address, n_bytes) + + def write(self, data_name: str, values_dict: dict[str | int, int], num_retry: int = 1) -> None: if not self.is_connected: raise DeviceNotConnectedError( - f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`." + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." ) start_time = time.perf_counter() - if motor_names is None: - motor_names = self.motor_names + ids_values = {self.get_safe_id(motor): val for motor, val in values_dict.items()} - if isinstance(motor_names, str): - motor_names = [motor_names] + if self._has_different_ctrl_tables: + models = [self.idx_to_model(idx) for idx in ids_values] + assert_same_address(self.model_ctrl_table, models, data_name) - if isinstance(values, (int, float, np.integer)): - values = [int(values)] * len(motor_names) + if data_name in self.calibration_required and self.calibration is not None: + ids_values = self.uncalibrate_values(ids_values) - self._write(data_name, values, motor_names) + model = self.idx_to_model(next(iter(ids_values))) + addr, n_bytes = self.model_ctrl_table[model][data_name] + + comm = self._write(ids_values, addr, n_bytes, num_retry) + if not self._is_comm_success(comm): + raise ConnectionError( + f"Failed to write {data_name} on port {self.port} for ids {list(ids_values)}:" + f"{self.packet_handler.getTxRxResult(comm)}" + ) # log the number of seconds it took to write the data to the motors - delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) + delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, list(ids_values)) self.logs[delta_ts_name] = time.perf_counter() - start_time - # TODO(rcadene): should we log the time before sending the write command? # log the utc time when the write has been completed - ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names) + ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values)) self.logs[ts_utc_name] = capture_timestamp_utc() - @abc.abstractmethod - def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None: - pass + def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 1) -> int: + self.writer.clearParam() + self.writer.start_address = address + self.writer.data_length = n_bytes + + for idx, value in ids_values.items(): + data = self.split_int_bytes(value, n_bytes) + self.writer.addParam(idx, data) + + for _ in range(num_retry): + comm = self.writer.txPacket() + if self._is_comm_success(comm): + break + + return comm def disconnect(self) -> None: if not self.is_connected: raise DeviceNotConnectedError( - f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first." + f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first." ) - if self.port_handler is not None: - self.port_handler.closePort() - self.port_handler = None - - self.packet_handler = None - self.group_readers = {} - self.group_writers = {} - self.is_connected = False + self.port_handler.closePort() def __del__(self): if self.is_connected: