Improve read ergonomics & typing, rm find_motor_indices

This commit is contained in:
Simon Alibert 2025-03-22 00:34:07 +01:00
parent 3d119c0ccb
commit 2abfa5838d
1 changed files with 37 additions and 51 deletions

View File

@ -27,16 +27,16 @@ from enum import Enum
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Protocol, TypeAlias, Union from typing import Protocol, TypeAlias, overload
import serial import serial
import tqdm
from deepdiff import DeepDiff from deepdiff import DeepDiff
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc from lerobot.common.utils.utils import capture_timestamp_utc
MotorLike: TypeAlias = Union[str, int, "Motor"] NameOrID: TypeAlias = str | int
Value: TypeAlias = int | float
MAX_ID_RANGE = 252 MAX_ID_RANGE = 252
@ -361,41 +361,20 @@ class MotorsBus(abc.ABC):
print(e) print(e)
return False return False
def ping(self, motor: MotorLike, num_retry: int | None = None) -> int: def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None:
idx = self.get_motor_id(motor) idx = self.get_motor_id(motor)
for _ in range(num_retry): for _ in range(1 + num_retry):
model_number, comm, _ = self.packet_handler.ping(self.port, idx) model_number, comm, error = self.packet_handler.ping(self.port_handler, idx)
if self._is_comm_success(comm): if self._is_comm_success(comm):
return model_number return model_number
# TODO(aliberts): Should we? if raise_on_error:
return comm raise ConnectionError(f"Ping motor {motor} returned a {error} error code.")
@abc.abstractmethod @abc.abstractmethod
def broadcast_ping(self, num_retry: int | None = None): def broadcast_ping(
... self, num_retry: int = 0, raise_on_error: bool = False
# TODO(aliberts): this will replace 'find_motor_indices' ) -> dict[int, list[int, int]] | None: ...
def find_motor_indices(self, possible_ids: list[str] = None, num_retry: int = 2):
if possible_ids is None:
possible_ids = range(MAX_ID_RANGE)
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read("ID", idx, num_retry=num_retry)[0]
except ConnectionError:
continue
if idx != present_idx:
# sanity check
raise OSError(
"Motor index used to communicate through the bus is not the same as the one present in the motor "
"memory. The motor memory might be damaged."
)
indices.append(idx)
return indices
def set_baudrate(self, baudrate) -> None: def set_baudrate(self, baudrate) -> None:
present_bus_baudrate = self.port_handler.getBaudRate() present_bus_baudrate = self.port_handler.getBaudRate()
@ -462,19 +441,23 @@ class MotorsBus(abc.ABC):
""" """
pass pass
def get_motor_id(self, motor: MotorLike) -> int: def get_motor_id(self, motor: NameOrID) -> int:
if isinstance(motor, str): if isinstance(motor, str):
return self.motors[motor].id return self.motors[motor].id
elif isinstance(motor, int): elif isinstance(motor, int):
return motor return motor
elif isinstance(motor, Motor):
return motor.id
else: else:
raise ValueError(f"{motor} should be int, str or Motor.") raise TypeError(f"'{motor}' should be int, str.")
@overload
def read(self, data_name: str, motors: None = ...) -> dict[str, Value]: ...
@overload
def read(self, data_name: str, motors: NameOrID) -> dict[NameOrID, Value]: ...
@overload
def read(self, data_name: str, motors: list[NameOrID]) -> dict[NameOrID, Value]: ...
def read( def read(
self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1 self, data_name: str, motors: NameOrID | list[NameOrID] | None = None, num_retry: int = 0
) -> dict[str, float]: ) -> dict[NameOrID, Value]:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
@ -482,13 +465,17 @@ class MotorsBus(abc.ABC):
start_time = time.perf_counter() start_time = time.perf_counter()
id_key_map: dict[int, NameOrID] = {}
if motors is None: if motors is None:
motors = self.ids id_key_map = {m.id: name for name, m in self.motors.items()}
elif isinstance(motors, (str, int)):
id_key_map = {self.get_motor_id(motors): motors}
elif isinstance(motors, list):
id_key_map = {self.get_motor_id(m): m for m in motors}
else:
raise TypeError(motors)
if isinstance(motors, (str, int)): motor_ids = list(id_key_map)
motors = [motors]
motor_ids = [self.get_motor_id(motor) for motor in motors]
if self._has_different_ctrl_tables: if self._has_different_ctrl_tables:
models = [self.id_to_model(idx) for idx in motor_ids] models = [self.id_to_model(idx) for idx in motor_ids]
assert_same_address(self.model_ctrl_table, models, data_name) assert_same_address(self.model_ctrl_table, models, data_name)
@ -506,8 +493,7 @@ class MotorsBus(abc.ABC):
if data_name in self.calibration_required and self.calibration is not None: if data_name in self.calibration_required and self.calibration is not None:
ids_values = self.calibrate_values(ids_values) ids_values = self.calibrate_values(ids_values)
# TODO(aliberts): return keys in the same format we got them? keys_values = {id_key_map[idx]: val for idx, val in ids_values.items()}
ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()}
# log the number of seconds it took to read the data from the motors # log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids) delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
@ -517,10 +503,10 @@ class MotorsBus(abc.ABC):
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids) ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids)
self.logs[ts_utc_name] = capture_timestamp_utc() self.logs[ts_utc_name] = capture_timestamp_utc()
return ids_values return keys_values
def _read( def _read(
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 1 self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 0
) -> tuple[int, dict[int, int]]: ) -> tuple[int, dict[int, int]]:
self.reader.clearParam() self.reader.clearParam()
self.reader.start_address = address self.reader.start_address = address
@ -534,7 +520,7 @@ class MotorsBus(abc.ABC):
for idx in motor_ids: for idx in motor_ids:
self.reader.addParam(idx) self.reader.addParam(idx)
for _ in range(num_retry): for _ in range(1 + num_retry):
comm = self.reader.txRxPacket() comm = self.reader.txRxPacket()
if self._is_comm_success(comm): if self._is_comm_success(comm):
break break
@ -551,7 +537,7 @@ class MotorsBus(abc.ABC):
# for idx in motor_ids: # for idx in motor_ids:
# value = self.reader.getData(idx, address, n_bytes) # value = self.reader.getData(idx, address, n_bytes)
def write(self, data_name: str, values: int | dict[MotorLike, int], num_retry: int = 1) -> None: def write(self, data_name: str, values: int | dict[NameOrID, int], num_retry: int = 0) -> None:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
@ -591,7 +577,7 @@ class MotorsBus(abc.ABC):
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values)) ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values))
self.logs[ts_utc_name] = capture_timestamp_utc() self.logs[ts_utc_name] = capture_timestamp_utc()
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 1) -> int: def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 0) -> int:
self.writer.clearParam() self.writer.clearParam()
self.writer.start_address = address self.writer.start_address = address
self.writer.data_length = n_bytes self.writer.data_length = n_bytes
@ -600,7 +586,7 @@ class MotorsBus(abc.ABC):
data = self.split_int_bytes(value, n_bytes) data = self.split_int_bytes(value, n_bytes)
self.writer.addParam(idx, data) self.writer.addParam(idx, data)
for _ in range(num_retry): for _ in range(1 + num_retry):
comm = self.writer.txPacket() comm = self.writer.txPacket()
if self._is_comm_success(comm): if self._is_comm_success(comm):
break break