From 2abfa5838d05795eb4df535773cb3d2cf67dd8b7 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sat, 22 Mar 2025 00:34:07 +0100 Subject: [PATCH] Improve read ergonomics & typing, rm find_motor_indices --- lerobot/common/motors/motors_bus.py | 88 ++++++++++++----------------- 1 file changed, 37 insertions(+), 51 deletions(-) diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 725ea454..8343d922 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -27,16 +27,16 @@ from enum import Enum from functools import cached_property from pathlib import Path from pprint import pformat -from typing import Protocol, TypeAlias, Union +from typing import Protocol, TypeAlias, overload import serial -import tqdm from deepdiff import DeepDiff from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.common.utils.utils import capture_timestamp_utc -MotorLike: TypeAlias = Union[str, int, "Motor"] +NameOrID: TypeAlias = str | int +Value: TypeAlias = int | float MAX_ID_RANGE = 252 @@ -361,41 +361,20 @@ class MotorsBus(abc.ABC): print(e) return False - def ping(self, motor: MotorLike, num_retry: int | None = None) -> int: + def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None: idx = self.get_motor_id(motor) - for _ in range(num_retry): - model_number, comm, _ = self.packet_handler.ping(self.port, idx) + for _ in range(1 + num_retry): + model_number, comm, error = self.packet_handler.ping(self.port_handler, idx) if self._is_comm_success(comm): return model_number - # TODO(aliberts): Should we? - return comm + if raise_on_error: + raise ConnectionError(f"Ping motor {motor} returned a {error} error code.") @abc.abstractmethod - def broadcast_ping(self, num_retry: int | None = None): - ... - # TODO(aliberts): this will replace 'find_motor_indices' - - def find_motor_indices(self, possible_ids: list[str] = None, num_retry: int = 2): - if possible_ids is None: - possible_ids = range(MAX_ID_RANGE) - - indices = [] - for idx in tqdm.tqdm(possible_ids): - try: - present_idx = self.read("ID", idx, num_retry=num_retry)[0] - except ConnectionError: - continue - - if idx != present_idx: - # sanity check - raise OSError( - "Motor index used to communicate through the bus is not the same as the one present in the motor " - "memory. The motor memory might be damaged." - ) - indices.append(idx) - - return indices + def broadcast_ping( + self, num_retry: int = 0, raise_on_error: bool = False + ) -> dict[int, list[int, int]] | None: ... def set_baudrate(self, baudrate) -> None: present_bus_baudrate = self.port_handler.getBaudRate() @@ -462,19 +441,23 @@ class MotorsBus(abc.ABC): """ pass - def get_motor_id(self, motor: MotorLike) -> int: + def get_motor_id(self, motor: NameOrID) -> int: if isinstance(motor, str): return self.motors[motor].id elif isinstance(motor, int): return motor - elif isinstance(motor, Motor): - return motor.id else: - raise ValueError(f"{motor} should be int, str or Motor.") + raise TypeError(f"'{motor}' should be int, str.") + @overload + def read(self, data_name: str, motors: None = ...) -> dict[str, Value]: ... + @overload + def read(self, data_name: str, motors: NameOrID) -> dict[NameOrID, Value]: ... + @overload + def read(self, data_name: str, motors: list[NameOrID]) -> dict[NameOrID, Value]: ... def read( - self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1 - ) -> dict[str, float]: + self, data_name: str, motors: NameOrID | list[NameOrID] | None = None, num_retry: int = 0 + ) -> dict[NameOrID, Value]: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." @@ -482,13 +465,17 @@ class MotorsBus(abc.ABC): start_time = time.perf_counter() + id_key_map: dict[int, NameOrID] = {} if motors is None: - motors = self.ids + id_key_map = {m.id: name for name, m in self.motors.items()} + elif isinstance(motors, (str, int)): + id_key_map = {self.get_motor_id(motors): motors} + elif isinstance(motors, list): + id_key_map = {self.get_motor_id(m): m for m in motors} + else: + raise TypeError(motors) - if isinstance(motors, (str, int)): - motors = [motors] - - motor_ids = [self.get_motor_id(motor) for motor in motors] + motor_ids = list(id_key_map) if self._has_different_ctrl_tables: models = [self.id_to_model(idx) for idx in motor_ids] assert_same_address(self.model_ctrl_table, models, data_name) @@ -506,8 +493,7 @@ class MotorsBus(abc.ABC): if data_name in self.calibration_required and self.calibration is not None: ids_values = self.calibrate_values(ids_values) - # TODO(aliberts): return keys in the same format we got them? - ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()} + keys_values = {id_key_map[idx]: val for idx, val in ids_values.items()} # log the number of seconds it took to read the data from the motors delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids) @@ -517,10 +503,10 @@ class MotorsBus(abc.ABC): ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids) self.logs[ts_utc_name] = capture_timestamp_utc() - return ids_values + return keys_values def _read( - self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 1 + self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 0 ) -> tuple[int, dict[int, int]]: self.reader.clearParam() self.reader.start_address = address @@ -534,7 +520,7 @@ class MotorsBus(abc.ABC): for idx in motor_ids: self.reader.addParam(idx) - for _ in range(num_retry): + for _ in range(1 + num_retry): comm = self.reader.txRxPacket() if self._is_comm_success(comm): break @@ -551,7 +537,7 @@ class MotorsBus(abc.ABC): # for idx in motor_ids: # value = self.reader.getData(idx, address, n_bytes) - def write(self, data_name: str, values: int | dict[MotorLike, int], num_retry: int = 1) -> None: + def write(self, data_name: str, values: int | dict[NameOrID, int], num_retry: int = 0) -> None: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." @@ -591,7 +577,7 @@ class MotorsBus(abc.ABC): ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values)) self.logs[ts_utc_name] = capture_timestamp_utc() - def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 1) -> int: + def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 0) -> int: self.writer.clearParam() self.writer.start_address = address self.writer.data_length = n_bytes @@ -600,7 +586,7 @@ class MotorsBus(abc.ABC): data = self.split_int_bytes(value, n_bytes) self.writer.addParam(idx, data) - for _ in range(num_retry): + for _ in range(1 + num_retry): comm = self.writer.txPacket() if self._is_comm_success(comm): break