Improve read ergonomics & typing, rm find_motor_indices
This commit is contained in:
parent
3d119c0ccb
commit
2abfa5838d
|
@ -27,16 +27,16 @@ from enum import Enum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Protocol, TypeAlias, Union
|
from typing import Protocol, TypeAlias, overload
|
||||||
|
|
||||||
import serial
|
import serial
|
||||||
import tqdm
|
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
|
|
||||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||||
|
|
||||||
MotorLike: TypeAlias = Union[str, int, "Motor"]
|
NameOrID: TypeAlias = str | int
|
||||||
|
Value: TypeAlias = int | float
|
||||||
|
|
||||||
MAX_ID_RANGE = 252
|
MAX_ID_RANGE = 252
|
||||||
|
|
||||||
|
@ -361,41 +361,20 @@ class MotorsBus(abc.ABC):
|
||||||
print(e)
|
print(e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def ping(self, motor: MotorLike, num_retry: int | None = None) -> int:
|
def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None:
|
||||||
idx = self.get_motor_id(motor)
|
idx = self.get_motor_id(motor)
|
||||||
for _ in range(num_retry):
|
for _ in range(1 + num_retry):
|
||||||
model_number, comm, _ = self.packet_handler.ping(self.port, idx)
|
model_number, comm, error = self.packet_handler.ping(self.port_handler, idx)
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
return model_number
|
return model_number
|
||||||
|
|
||||||
# TODO(aliberts): Should we?
|
if raise_on_error:
|
||||||
return comm
|
raise ConnectionError(f"Ping motor {motor} returned a {error} error code.")
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def broadcast_ping(self, num_retry: int | None = None):
|
def broadcast_ping(
|
||||||
...
|
self, num_retry: int = 0, raise_on_error: bool = False
|
||||||
# TODO(aliberts): this will replace 'find_motor_indices'
|
) -> dict[int, list[int, int]] | None: ...
|
||||||
|
|
||||||
def find_motor_indices(self, possible_ids: list[str] = None, num_retry: int = 2):
|
|
||||||
if possible_ids is None:
|
|
||||||
possible_ids = range(MAX_ID_RANGE)
|
|
||||||
|
|
||||||
indices = []
|
|
||||||
for idx in tqdm.tqdm(possible_ids):
|
|
||||||
try:
|
|
||||||
present_idx = self.read("ID", idx, num_retry=num_retry)[0]
|
|
||||||
except ConnectionError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if idx != present_idx:
|
|
||||||
# sanity check
|
|
||||||
raise OSError(
|
|
||||||
"Motor index used to communicate through the bus is not the same as the one present in the motor "
|
|
||||||
"memory. The motor memory might be damaged."
|
|
||||||
)
|
|
||||||
indices.append(idx)
|
|
||||||
|
|
||||||
return indices
|
|
||||||
|
|
||||||
def set_baudrate(self, baudrate) -> None:
|
def set_baudrate(self, baudrate) -> None:
|
||||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||||
|
@ -462,19 +441,23 @@ class MotorsBus(abc.ABC):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_motor_id(self, motor: MotorLike) -> int:
|
def get_motor_id(self, motor: NameOrID) -> int:
|
||||||
if isinstance(motor, str):
|
if isinstance(motor, str):
|
||||||
return self.motors[motor].id
|
return self.motors[motor].id
|
||||||
elif isinstance(motor, int):
|
elif isinstance(motor, int):
|
||||||
return motor
|
return motor
|
||||||
elif isinstance(motor, Motor):
|
|
||||||
return motor.id
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{motor} should be int, str or Motor.")
|
raise TypeError(f"'{motor}' should be int, str.")
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def read(self, data_name: str, motors: None = ...) -> dict[str, Value]: ...
|
||||||
|
@overload
|
||||||
|
def read(self, data_name: str, motors: NameOrID) -> dict[NameOrID, Value]: ...
|
||||||
|
@overload
|
||||||
|
def read(self, data_name: str, motors: list[NameOrID]) -> dict[NameOrID, Value]: ...
|
||||||
def read(
|
def read(
|
||||||
self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1
|
self, data_name: str, motors: NameOrID | list[NameOrID] | None = None, num_retry: int = 0
|
||||||
) -> dict[str, float]:
|
) -> dict[NameOrID, Value]:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||||
|
@ -482,13 +465,17 @@ class MotorsBus(abc.ABC):
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
id_key_map: dict[int, NameOrID] = {}
|
||||||
if motors is None:
|
if motors is None:
|
||||||
motors = self.ids
|
id_key_map = {m.id: name for name, m in self.motors.items()}
|
||||||
|
elif isinstance(motors, (str, int)):
|
||||||
|
id_key_map = {self.get_motor_id(motors): motors}
|
||||||
|
elif isinstance(motors, list):
|
||||||
|
id_key_map = {self.get_motor_id(m): m for m in motors}
|
||||||
|
else:
|
||||||
|
raise TypeError(motors)
|
||||||
|
|
||||||
if isinstance(motors, (str, int)):
|
motor_ids = list(id_key_map)
|
||||||
motors = [motors]
|
|
||||||
|
|
||||||
motor_ids = [self.get_motor_id(motor) for motor in motors]
|
|
||||||
if self._has_different_ctrl_tables:
|
if self._has_different_ctrl_tables:
|
||||||
models = [self.id_to_model(idx) for idx in motor_ids]
|
models = [self.id_to_model(idx) for idx in motor_ids]
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
|
@ -506,8 +493,7 @@ class MotorsBus(abc.ABC):
|
||||||
if data_name in self.calibration_required and self.calibration is not None:
|
if data_name in self.calibration_required and self.calibration is not None:
|
||||||
ids_values = self.calibrate_values(ids_values)
|
ids_values = self.calibrate_values(ids_values)
|
||||||
|
|
||||||
# TODO(aliberts): return keys in the same format we got them?
|
keys_values = {id_key_map[idx]: val for idx, val in ids_values.items()}
|
||||||
ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()}
|
|
||||||
|
|
||||||
# log the number of seconds it took to read the data from the motors
|
# log the number of seconds it took to read the data from the motors
|
||||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
|
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
|
||||||
|
@ -517,10 +503,10 @@ class MotorsBus(abc.ABC):
|
||||||
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids)
|
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids)
|
||||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||||
|
|
||||||
return ids_values
|
return keys_values
|
||||||
|
|
||||||
def _read(
|
def _read(
|
||||||
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 1
|
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 0
|
||||||
) -> tuple[int, dict[int, int]]:
|
) -> tuple[int, dict[int, int]]:
|
||||||
self.reader.clearParam()
|
self.reader.clearParam()
|
||||||
self.reader.start_address = address
|
self.reader.start_address = address
|
||||||
|
@ -534,7 +520,7 @@ class MotorsBus(abc.ABC):
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
self.reader.addParam(idx)
|
self.reader.addParam(idx)
|
||||||
|
|
||||||
for _ in range(num_retry):
|
for _ in range(1 + num_retry):
|
||||||
comm = self.reader.txRxPacket()
|
comm = self.reader.txRxPacket()
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
|
@ -551,7 +537,7 @@ class MotorsBus(abc.ABC):
|
||||||
# for idx in motor_ids:
|
# for idx in motor_ids:
|
||||||
# value = self.reader.getData(idx, address, n_bytes)
|
# value = self.reader.getData(idx, address, n_bytes)
|
||||||
|
|
||||||
def write(self, data_name: str, values: int | dict[MotorLike, int], num_retry: int = 1) -> None:
|
def write(self, data_name: str, values: int | dict[NameOrID, int], num_retry: int = 0) -> None:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||||
|
@ -591,7 +577,7 @@ class MotorsBus(abc.ABC):
|
||||||
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values))
|
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values))
|
||||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||||
|
|
||||||
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 1) -> int:
|
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 0) -> int:
|
||||||
self.writer.clearParam()
|
self.writer.clearParam()
|
||||||
self.writer.start_address = address
|
self.writer.start_address = address
|
||||||
self.writer.data_length = n_bytes
|
self.writer.data_length = n_bytes
|
||||||
|
@ -600,7 +586,7 @@ class MotorsBus(abc.ABC):
|
||||||
data = self.split_int_bytes(value, n_bytes)
|
data = self.split_int_bytes(value, n_bytes)
|
||||||
self.writer.addParam(idx, data)
|
self.writer.addParam(idx, data)
|
||||||
|
|
||||||
for _ in range(num_retry):
|
for _ in range(1 + num_retry):
|
||||||
comm = self.writer.txPacket()
|
comm = self.writer.txPacket()
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
|
|
Loading…
Reference in New Issue