Improve read ergonomics & typing, rm find_motor_indices
This commit is contained in:
parent
3d119c0ccb
commit
2abfa5838d
|
@ -27,16 +27,16 @@ from enum import Enum
|
|||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Protocol, TypeAlias, Union
|
||||
from typing import Protocol, TypeAlias, overload
|
||||
|
||||
import serial
|
||||
import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
MotorLike: TypeAlias = Union[str, int, "Motor"]
|
||||
NameOrID: TypeAlias = str | int
|
||||
Value: TypeAlias = int | float
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
|
@ -361,41 +361,20 @@ class MotorsBus(abc.ABC):
|
|||
print(e)
|
||||
return False
|
||||
|
||||
def ping(self, motor: MotorLike, num_retry: int | None = None) -> int:
|
||||
def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None:
|
||||
idx = self.get_motor_id(motor)
|
||||
for _ in range(num_retry):
|
||||
model_number, comm, _ = self.packet_handler.ping(self.port, idx)
|
||||
for _ in range(1 + num_retry):
|
||||
model_number, comm, error = self.packet_handler.ping(self.port_handler, idx)
|
||||
if self._is_comm_success(comm):
|
||||
return model_number
|
||||
|
||||
# TODO(aliberts): Should we?
|
||||
return comm
|
||||
if raise_on_error:
|
||||
raise ConnectionError(f"Ping motor {motor} returned a {error} error code.")
|
||||
|
||||
@abc.abstractmethod
|
||||
def broadcast_ping(self, num_retry: int | None = None):
|
||||
...
|
||||
# TODO(aliberts): this will replace 'find_motor_indices'
|
||||
|
||||
def find_motor_indices(self, possible_ids: list[str] = None, num_retry: int = 2):
|
||||
if possible_ids is None:
|
||||
possible_ids = range(MAX_ID_RANGE)
|
||||
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read("ID", idx, num_retry=num_retry)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
if idx != present_idx:
|
||||
# sanity check
|
||||
raise OSError(
|
||||
"Motor index used to communicate through the bus is not the same as the one present in the motor "
|
||||
"memory. The motor memory might be damaged."
|
||||
)
|
||||
indices.append(idx)
|
||||
|
||||
return indices
|
||||
def broadcast_ping(
|
||||
self, num_retry: int = 0, raise_on_error: bool = False
|
||||
) -> dict[int, list[int, int]] | None: ...
|
||||
|
||||
def set_baudrate(self, baudrate) -> None:
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
|
@ -462,19 +441,23 @@ class MotorsBus(abc.ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def get_motor_id(self, motor: MotorLike) -> int:
|
||||
def get_motor_id(self, motor: NameOrID) -> int:
|
||||
if isinstance(motor, str):
|
||||
return self.motors[motor].id
|
||||
elif isinstance(motor, int):
|
||||
return motor
|
||||
elif isinstance(motor, Motor):
|
||||
return motor.id
|
||||
else:
|
||||
raise ValueError(f"{motor} should be int, str or Motor.")
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
@overload
|
||||
def read(self, data_name: str, motors: None = ...) -> dict[str, Value]: ...
|
||||
@overload
|
||||
def read(self, data_name: str, motors: NameOrID) -> dict[NameOrID, Value]: ...
|
||||
@overload
|
||||
def read(self, data_name: str, motors: list[NameOrID]) -> dict[NameOrID, Value]: ...
|
||||
def read(
|
||||
self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1
|
||||
) -> dict[str, float]:
|
||||
self, data_name: str, motors: NameOrID | list[NameOrID] | None = None, num_retry: int = 0
|
||||
) -> dict[NameOrID, Value]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
|
@ -482,13 +465,17 @@ class MotorsBus(abc.ABC):
|
|||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
id_key_map: dict[int, NameOrID] = {}
|
||||
if motors is None:
|
||||
motors = self.ids
|
||||
id_key_map = {m.id: name for name, m in self.motors.items()}
|
||||
elif isinstance(motors, (str, int)):
|
||||
id_key_map = {self.get_motor_id(motors): motors}
|
||||
elif isinstance(motors, list):
|
||||
id_key_map = {self.get_motor_id(m): m for m in motors}
|
||||
else:
|
||||
raise TypeError(motors)
|
||||
|
||||
if isinstance(motors, (str, int)):
|
||||
motors = [motors]
|
||||
|
||||
motor_ids = [self.get_motor_id(motor) for motor in motors]
|
||||
motor_ids = list(id_key_map)
|
||||
if self._has_different_ctrl_tables:
|
||||
models = [self.id_to_model(idx) for idx in motor_ids]
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
@ -506,8 +493,7 @@ class MotorsBus(abc.ABC):
|
|||
if data_name in self.calibration_required and self.calibration is not None:
|
||||
ids_values = self.calibrate_values(ids_values)
|
||||
|
||||
# TODO(aliberts): return keys in the same format we got them?
|
||||
ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()}
|
||||
keys_values = {id_key_map[idx]: val for idx, val in ids_values.items()}
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
|
||||
|
@ -517,10 +503,10 @@ class MotorsBus(abc.ABC):
|
|||
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids)
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
return ids_values
|
||||
return keys_values
|
||||
|
||||
def _read(
|
||||
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 1
|
||||
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 0
|
||||
) -> tuple[int, dict[int, int]]:
|
||||
self.reader.clearParam()
|
||||
self.reader.start_address = address
|
||||
|
@ -534,7 +520,7 @@ class MotorsBus(abc.ABC):
|
|||
for idx in motor_ids:
|
||||
self.reader.addParam(idx)
|
||||
|
||||
for _ in range(num_retry):
|
||||
for _ in range(1 + num_retry):
|
||||
comm = self.reader.txRxPacket()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
|
@ -551,7 +537,7 @@ class MotorsBus(abc.ABC):
|
|||
# for idx in motor_ids:
|
||||
# value = self.reader.getData(idx, address, n_bytes)
|
||||
|
||||
def write(self, data_name: str, values: int | dict[MotorLike, int], num_retry: int = 1) -> None:
|
||||
def write(self, data_name: str, values: int | dict[NameOrID, int], num_retry: int = 0) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
|
@ -591,7 +577,7 @@ class MotorsBus(abc.ABC):
|
|||
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values))
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 1) -> int:
|
||||
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 0) -> int:
|
||||
self.writer.clearParam()
|
||||
self.writer.start_address = address
|
||||
self.writer.data_length = n_bytes
|
||||
|
@ -600,7 +586,7 @@ class MotorsBus(abc.ABC):
|
|||
data = self.split_int_bytes(value, n_bytes)
|
||||
self.writer.addParam(idx, data)
|
||||
|
||||
for _ in range(num_retry):
|
||||
for _ in range(1 + num_retry):
|
||||
comm = self.writer.txPacket()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
|
|
Loading…
Reference in New Issue