Improve read ergonomics & typing, rm find_motor_indices

This commit is contained in:
Simon Alibert 2025-03-22 00:34:07 +01:00
parent 3d119c0ccb
commit 2abfa5838d
1 changed files with 37 additions and 51 deletions

View File

@ -27,16 +27,16 @@ from enum import Enum
from functools import cached_property
from pathlib import Path
from pprint import pformat
from typing import Protocol, TypeAlias, Union
from typing import Protocol, TypeAlias, overload
import serial
import tqdm
from deepdiff import DeepDiff
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
MotorLike: TypeAlias = Union[str, int, "Motor"]
NameOrID: TypeAlias = str | int
Value: TypeAlias = int | float
MAX_ID_RANGE = 252
@ -361,41 +361,20 @@ class MotorsBus(abc.ABC):
print(e)
return False
def ping(self, motor: MotorLike, num_retry: int | None = None) -> int:
def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None:
idx = self.get_motor_id(motor)
for _ in range(num_retry):
model_number, comm, _ = self.packet_handler.ping(self.port, idx)
for _ in range(1 + num_retry):
model_number, comm, error = self.packet_handler.ping(self.port_handler, idx)
if self._is_comm_success(comm):
return model_number
# TODO(aliberts): Should we?
return comm
if raise_on_error:
raise ConnectionError(f"Ping motor {motor} returned a {error} error code.")
@abc.abstractmethod
def broadcast_ping(self, num_retry: int | None = None):
...
# TODO(aliberts): this will replace 'find_motor_indices'
def find_motor_indices(self, possible_ids: list[str] = None, num_retry: int = 2):
if possible_ids is None:
possible_ids = range(MAX_ID_RANGE)
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read("ID", idx, num_retry=num_retry)[0]
except ConnectionError:
continue
if idx != present_idx:
# sanity check
raise OSError(
"Motor index used to communicate through the bus is not the same as the one present in the motor "
"memory. The motor memory might be damaged."
)
indices.append(idx)
return indices
def broadcast_ping(
self, num_retry: int = 0, raise_on_error: bool = False
) -> dict[int, list[int, int]] | None: ...
def set_baudrate(self, baudrate) -> None:
present_bus_baudrate = self.port_handler.getBaudRate()
@ -462,19 +441,23 @@ class MotorsBus(abc.ABC):
"""
pass
def get_motor_id(self, motor: MotorLike) -> int:
def get_motor_id(self, motor: NameOrID) -> int:
if isinstance(motor, str):
return self.motors[motor].id
elif isinstance(motor, int):
return motor
elif isinstance(motor, Motor):
return motor.id
else:
raise ValueError(f"{motor} should be int, str or Motor.")
raise TypeError(f"'{motor}' should be int, str.")
@overload
def read(self, data_name: str, motors: None = ...) -> dict[str, Value]: ...
@overload
def read(self, data_name: str, motors: NameOrID) -> dict[NameOrID, Value]: ...
@overload
def read(self, data_name: str, motors: list[NameOrID]) -> dict[NameOrID, Value]: ...
def read(
self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1
) -> dict[str, float]:
self, data_name: str, motors: NameOrID | list[NameOrID] | None = None, num_retry: int = 0
) -> dict[NameOrID, Value]:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
@ -482,13 +465,17 @@ class MotorsBus(abc.ABC):
start_time = time.perf_counter()
id_key_map: dict[int, NameOrID] = {}
if motors is None:
motors = self.ids
id_key_map = {m.id: name for name, m in self.motors.items()}
elif isinstance(motors, (str, int)):
id_key_map = {self.get_motor_id(motors): motors}
elif isinstance(motors, list):
id_key_map = {self.get_motor_id(m): m for m in motors}
else:
raise TypeError(motors)
if isinstance(motors, (str, int)):
motors = [motors]
motor_ids = [self.get_motor_id(motor) for motor in motors]
motor_ids = list(id_key_map)
if self._has_different_ctrl_tables:
models = [self.id_to_model(idx) for idx in motor_ids]
assert_same_address(self.model_ctrl_table, models, data_name)
@ -506,8 +493,7 @@ class MotorsBus(abc.ABC):
if data_name in self.calibration_required and self.calibration is not None:
ids_values = self.calibrate_values(ids_values)
# TODO(aliberts): return keys in the same format we got them?
ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()}
keys_values = {id_key_map[idx]: val for idx, val in ids_values.items()}
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
@ -517,10 +503,10 @@ class MotorsBus(abc.ABC):
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids)
self.logs[ts_utc_name] = capture_timestamp_utc()
return ids_values
return keys_values
def _read(
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 1
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 0
) -> tuple[int, dict[int, int]]:
self.reader.clearParam()
self.reader.start_address = address
@ -534,7 +520,7 @@ class MotorsBus(abc.ABC):
for idx in motor_ids:
self.reader.addParam(idx)
for _ in range(num_retry):
for _ in range(1 + num_retry):
comm = self.reader.txRxPacket()
if self._is_comm_success(comm):
break
@ -551,7 +537,7 @@ class MotorsBus(abc.ABC):
# for idx in motor_ids:
# value = self.reader.getData(idx, address, n_bytes)
def write(self, data_name: str, values: int | dict[MotorLike, int], num_retry: int = 1) -> None:
def write(self, data_name: str, values: int | dict[NameOrID, int], num_retry: int = 0) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
@ -591,7 +577,7 @@ class MotorsBus(abc.ABC):
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values))
self.logs[ts_utc_name] = capture_timestamp_utc()
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 1) -> int:
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 0) -> int:
self.writer.clearParam()
self.writer.start_address = address
self.writer.data_length = n_bytes
@ -600,7 +586,7 @@ class MotorsBus(abc.ABC):
data = self.split_int_bytes(value, n_bytes)
self.writer.addParam(idx, data)
for _ in range(num_retry):
for _ in range(1 + num_retry):
comm = self.writer.txPacket()
if self._is_comm_success(comm):
break