Rewrite MotorsBus
This commit is contained in:
parent
c85a9253e7
commit
9358d334c7
|
@ -20,13 +20,17 @@
|
|||
# ruff: noqa: N802
|
||||
|
||||
import abc
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
import serial
|
||||
import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
@ -34,8 +38,8 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
|||
MAX_ID_RANGE = 252
|
||||
|
||||
|
||||
def get_group_sync_key(data_name: str, motor_names: list[str]) -> str:
|
||||
group_key = f"{data_name}_" + "_".join(motor_names)
|
||||
def get_group_sync_key(data_name: str, motor_ids: list[int]) -> str:
|
||||
group_key = f"{data_name}_" + "_".join([str(idx) for idx in motor_ids])
|
||||
return group_key
|
||||
|
||||
|
||||
|
@ -98,7 +102,7 @@ class PortHandler(Protocol):
|
|||
self.tx_time_per_byte: float
|
||||
self.is_using: bool
|
||||
self.port_name: str
|
||||
self.ser: object
|
||||
self.ser: serial.Serial
|
||||
|
||||
def openPort(self): ...
|
||||
def closePort(self): ...
|
||||
|
@ -153,6 +157,46 @@ class PacketHandler(Protocol):
|
|||
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
|
||||
|
||||
|
||||
class GroupSyncRead(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.last_result: bool
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id): ...
|
||||
def removeParam(self, id): ...
|
||||
def clearParam(self): ...
|
||||
def txPacket(self): ...
|
||||
def rxPacket(self): ...
|
||||
def txRxPacket(self): ...
|
||||
def isAvailable(self, id, address, data_length): ...
|
||||
def getData(self, id, address, data_length): ...
|
||||
|
||||
|
||||
class GroupSyncWrite(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id, data): ...
|
||||
def removeParam(self, id): ...
|
||||
def changeParam(self, id, data): ...
|
||||
def clearParam(self): ...
|
||||
def txPacket(self): ...
|
||||
|
||||
|
||||
class MotorsBus(abc.ABC):
|
||||
"""The main LeRobot class for implementing motors buses.
|
||||
|
||||
|
@ -163,7 +207,7 @@ class MotorsBus(abc.ABC):
|
|||
Note: This class may evolve in the future should we add support for other manufacturers SDKs.
|
||||
|
||||
A MotorsBus allows to efficiently read and write to the attached motors.
|
||||
It represents a several motors daisy-chained together and connected through a serial port.
|
||||
It represents several motors daisy-chained together and connected through a serial port.
|
||||
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
|
@ -198,6 +242,8 @@ class MotorsBus(abc.ABC):
|
|||
model_ctrl_table: dict[str, dict]
|
||||
model_resolution_table: dict[str, int]
|
||||
model_baudrate_table: dict[str, dict]
|
||||
calibration_required: list[str]
|
||||
default_timeout: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -206,73 +252,121 @@ class MotorsBus(abc.ABC):
|
|||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.port_handler: PortHandler | None = None
|
||||
self.packet_handler: PacketHandler | None = None
|
||||
self._validate_motors()
|
||||
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
self.logs = {}
|
||||
self.port_handler: PortHandler
|
||||
self.packet_handler: PacketHandler
|
||||
self.reader: GroupSyncRead
|
||||
self.writer: GroupSyncWrite
|
||||
|
||||
self.logs = {} # TODO(aliberts): use subclass logger
|
||||
self.calibration = None
|
||||
self.is_connected: bool = False
|
||||
|
||||
self._id_to_model = dict(self.motors.values())
|
||||
self._id_to_name = {idx: name for name, (idx, _) in self.motors.items()}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.motors)
|
||||
|
||||
@property
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Port: '{self.port}',\n"
|
||||
f" Motors: \n{pformat(self.motors, indent=8)},\n"
|
||||
")',\n"
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _has_different_ctrl_tables(self) -> bool:
|
||||
if len(self.motor_models) < 2:
|
||||
return False
|
||||
|
||||
first_table = self.model_ctrl_table[self.motor_models[0]]
|
||||
return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.motor_models[1:])
|
||||
|
||||
def idx_to_model(self, idx: int) -> str:
|
||||
return self._id_to_model[idx]
|
||||
|
||||
def idx_to_name(self, idx: int) -> str:
|
||||
return self._id_to_name[idx]
|
||||
|
||||
@cached_property
|
||||
def motor_names(self) -> list[str]:
|
||||
return list(self.motors)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def motor_models(self) -> list[str]:
|
||||
return [model for _, model in self.motors.values()]
|
||||
|
||||
@property
|
||||
def motor_indices(self) -> list[int]:
|
||||
@cached_property
|
||||
def motor_ids(self) -> list[int]:
|
||||
return [idx for idx, _ in self.motors.values()]
|
||||
|
||||
def connect(self):
|
||||
def _validate_motors(self) -> None:
|
||||
# TODO(aliberts): improve error messages for this (display problematics values)
|
||||
if len(self.motor_ids) != len(set(self.motor_ids)):
|
||||
raise ValueError("Some motors have the same id.")
|
||||
|
||||
if len(self.motor_names) != len(set(self.motor_names)):
|
||||
raise ValueError("Some motors have the same name.")
|
||||
|
||||
if any(model not in self.model_resolution_table for model in self.motor_models):
|
||||
raise ValueError("Some motors models are not available.")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.port_handler.is_open
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__name__}({self.port}) is already connected. Do not call `{self.__name__}.connect()` twice."
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
|
||||
)
|
||||
|
||||
self._set_handlers()
|
||||
|
||||
try:
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||
print(
|
||||
"\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n"
|
||||
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
|
||||
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
|
||||
)
|
||||
raise
|
||||
raise e
|
||||
|
||||
self._set_timeout()
|
||||
self.set_timeout()
|
||||
|
||||
# Allow to read and write
|
||||
self.is_connected = True
|
||||
def set_timeout(self, timeout_ms: int | None = None):
|
||||
timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout
|
||||
self.port_handler.setPacketTimeoutMillis(timeout_ms)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _set_handlers(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _set_timeout(self, timeout: int):
|
||||
pass
|
||||
|
||||
def are_motors_configured(self):
|
||||
@property
|
||||
def are_motors_configured(self) -> bool:
|
||||
"""
|
||||
Only check the motor indices and not baudrate, since if the motor baudrates are incorrect, a
|
||||
ConnectionError will be raised anyway.
|
||||
"""
|
||||
try:
|
||||
return (self.motor_indices == self.read("ID")).all()
|
||||
# TODO(aliberts): use ping instead
|
||||
return (self.motor_ids == self.read("ID")).all()
|
||||
except ConnectionError as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
def ping(self, motor: str | int, num_retry: int | None = None) -> int:
|
||||
idx = self.get_safe_id(motor)
|
||||
for _ in range(num_retry):
|
||||
model_number, comm, _ = self.packet_handler.ping(self.port, idx)
|
||||
if self._is_comm_success(comm):
|
||||
return model_number
|
||||
|
||||
# TODO(aliberts): Should we?
|
||||
return comm
|
||||
|
||||
@abc.abstractmethod
|
||||
def broadcast_ping(self, num_retry: int | None = None):
|
||||
...
|
||||
# TODO(aliberts): this will replace 'find_motor_indices'
|
||||
|
||||
def find_motor_indices(self, possible_ids: list[str] = None, num_retry: int = 2):
|
||||
if possible_ids is None:
|
||||
possible_ids = range(MAX_ID_RANGE)
|
||||
|
@ -280,7 +374,7 @@ class MotorsBus(abc.ABC):
|
|||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
present_idx = self.read("ID", idx, num_retry=num_retry)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
|
@ -294,7 +388,7 @@ class MotorsBus(abc.ABC):
|
|||
|
||||
return indices
|
||||
|
||||
def set_baudrate(self, baudrate):
|
||||
def set_baudrate(self, baudrate) -> None:
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
|
@ -303,94 +397,207 @@ class MotorsBus(abc.ABC):
|
|||
if self.port_handler.getBaudRate() != baudrate:
|
||||
raise OSError("Failed to write bus baud rate.")
|
||||
|
||||
def set_calibration(self, calibration_dict: dict[str, list]):
|
||||
self.calibration = calibration_dict
|
||||
def set_calibration(self, calibration_fpath: Path) -> None:
|
||||
with open(calibration_fpath) as f:
|
||||
calibration = json.load(f)
|
||||
|
||||
self.calibration = {int(idx): val for idx, val in calibration.items()}
|
||||
|
||||
@abc.abstractmethod
|
||||
def apply_calibration(self):
|
||||
def calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def revert_calibration(self):
|
||||
def uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]:
|
||||
pass
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
@abc.abstractmethod
|
||||
def _is_comm_success(self, comm: int) -> bool:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def split_int_bytes(value: int, n_bytes: int) -> list[int]:
|
||||
"""
|
||||
Splits an unsigned integer into a list of bytes in little-endian order.
|
||||
|
||||
This function extracts the individual bytes of an integer based on the
|
||||
specified number of bytes (`n_bytes`). The output is a list of integers,
|
||||
each representing a byte (0-255).
|
||||
|
||||
**Byte order:** The function returns bytes in **little-endian format**,
|
||||
meaning the least significant byte (LSB) comes first.
|
||||
|
||||
Args:
|
||||
value (int): The unsigned integer to be converted into a byte list. Must be within
|
||||
the valid range for the specified `n_bytes`.
|
||||
n_bytes (int): The number of bytes to use for conversion. Supported values:
|
||||
- 1 (for values 0 to 255)
|
||||
- 2 (for values 0 to 65,535)
|
||||
- 4 (for values 0 to 4,294,967,295)
|
||||
|
||||
Raises:
|
||||
ValueError: If `value` is negative or exceeds the maximum allowed for `n_bytes`.
|
||||
NotImplementedError: If `n_bytes` is not 1, 2, or 4.
|
||||
|
||||
Returns:
|
||||
list[int]: A list of integers, each representing a byte in **little-endian order**.
|
||||
|
||||
Examples:
|
||||
>>> split_int_bytes(0x12, 1)
|
||||
[18]
|
||||
>>> split_int_bytes(0x1234, 2)
|
||||
[52, 18] # 0x1234 → 0x34 0x12 (little-endian)
|
||||
>>> split_int_bytes(0x12345678, 4)
|
||||
[120, 86, 52, 18] # 0x12345678 → 0x78 0x56 0x34 0x12
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_safe_id(self, motor: str | int) -> int:
|
||||
if isinstance(motor, str):
|
||||
return self.motors[motor][0]
|
||||
elif isinstance(motor, int):
|
||||
return motor
|
||||
else:
|
||||
raise ValueError(f"{motor} should be int or str.")
|
||||
|
||||
def read(
|
||||
self, data_name: str, motors: str | int | list[str | int] | None = None, num_retry: int = 1
|
||||
) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`."
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
if isinstance(motor_names, str):
|
||||
motor_names = [motor_names]
|
||||
if motors is None:
|
||||
motors = self.motor_ids
|
||||
|
||||
values = self._read(data_name, motor_names)
|
||||
if isinstance(motors, (str, int)):
|
||||
motors = [motors]
|
||||
|
||||
motor_ids = [self.get_safe_id(motor) for motor in motors]
|
||||
if self._has_different_ctrl_tables:
|
||||
models = [self.idx_to_model(idx) for idx in motor_ids]
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = self.idx_to_model(next(iter(motor_ids)))
|
||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||
|
||||
comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to read {data_name} on port {self.port} for ids {motor_ids}:"
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
if data_name in self.calibration_required and self.calibration is not None:
|
||||
ids_values = self.calibrate_values(ids_values)
|
||||
|
||||
# TODO(aliberts): return keys in the same format we got them?
|
||||
ids_values = {self.idx_to_name(idx): val for idx, val in ids_values.items()}
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
|
||||
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids)
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
return values
|
||||
return ids_values
|
||||
|
||||
@abc.abstractmethod
|
||||
def _read(self, data_name: str, motor_names: list[str]):
|
||||
pass
|
||||
def _read(
|
||||
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 1
|
||||
) -> tuple[int, dict[int, int]]:
|
||||
self.reader.clearParam()
|
||||
self.reader.start_address = address
|
||||
self.reader.data_length = n_bytes
|
||||
|
||||
def write(
|
||||
self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None
|
||||
) -> None:
|
||||
# FIXME(aliberts, pkooij): We should probably not have to do this.
|
||||
# Let's try to see if we can do with better comm status handling instead.
|
||||
# self.port_handler.ser.reset_output_buffer()
|
||||
# self.port_handler.ser.reset_input_buffer()
|
||||
|
||||
for idx in motor_ids:
|
||||
self.reader.addParam(idx)
|
||||
|
||||
for _ in range(num_retry):
|
||||
comm = self.reader.txRxPacket()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
|
||||
values = {idx: self.reader.getData(idx, address, n_bytes) for idx in motor_ids}
|
||||
return comm, values
|
||||
|
||||
# TODO(aliberts, pkooij): Implementing something like this could get much faster read times.
|
||||
# Note: this could be at the cost of increase latency between the moment the data is produced by the
|
||||
# motors and the moment it is used by a policy
|
||||
# def _async_read(self, motor_ids: list[str], address: int, n_bytes: int):
|
||||
# self.reader.rxPacket()
|
||||
# self.reader.txPacket()
|
||||
# for idx in motor_ids:
|
||||
# value = self.reader.getData(idx, address, n_bytes)
|
||||
|
||||
def write(self, data_name: str, values_dict: dict[str | int, int], num_retry: int = 1) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`."
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
ids_values = {self.get_safe_id(motor): val for motor, val in values_dict.items()}
|
||||
|
||||
if isinstance(motor_names, str):
|
||||
motor_names = [motor_names]
|
||||
if self._has_different_ctrl_tables:
|
||||
models = [self.idx_to_model(idx) for idx in ids_values]
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
if isinstance(values, (int, float, np.integer)):
|
||||
values = [int(values)] * len(motor_names)
|
||||
if data_name in self.calibration_required and self.calibration is not None:
|
||||
ids_values = self.uncalibrate_values(ids_values)
|
||||
|
||||
self._write(data_name, values, motor_names)
|
||||
model = self.idx_to_model(next(iter(ids_values)))
|
||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||
|
||||
comm = self._write(ids_values, addr, n_bytes, num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to write {data_name} on port {self.port} for ids {list(ids_values)}:"
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, list(ids_values))
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
# log the utc time when the write has been completed
|
||||
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
|
||||
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values))
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None:
|
||||
pass
|
||||
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 1) -> int:
|
||||
self.writer.clearParam()
|
||||
self.writer.start_address = address
|
||||
self.writer.data_length = n_bytes
|
||||
|
||||
for idx, value in ids_values.items():
|
||||
data = self.split_int_bytes(value, n_bytes)
|
||||
self.writer.addParam(idx, data)
|
||||
|
||||
for _ in range(num_retry):
|
||||
comm = self.writer.txPacket()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
|
||||
return comm
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first."
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
|
||||
)
|
||||
|
||||
if self.port_handler is not None:
|
||||
self.port_handler.closePort()
|
||||
self.port_handler = None
|
||||
|
||||
self.packet_handler = None
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
self.is_connected = False
|
||||
self.port_handler.closePort()
|
||||
|
||||
def __del__(self):
|
||||
if self.is_connected:
|
||||
|
|
Loading…
Reference in New Issue