lerobot/lerobot/common/motors/motors_bus.py

693 lines
26 KiB
Python
Raw Normal View History

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: N802
2025-03-22 18:11:39 +08:00
# This noqa is for the Protocol classes: PortHandler, PacketHandler, GroupSyncRead/Write
# TODO(aliberts): Add block noqa when feature below is available
# https://github.com/astral-sh/ruff/issues/3711
2025-03-04 01:18:24 +08:00
import abc
2025-03-22 17:33:42 +08:00
import logging
2025-03-21 19:13:44 +08:00
from dataclasses import dataclass
2025-03-16 04:33:45 +08:00
from enum import Enum
2025-03-20 01:44:05 +08:00
from functools import cached_property
from pprint import pformat
from typing import Protocol, TypeAlias, overload
2025-03-20 01:44:05 +08:00
import serial
from deepdiff import DeepDiff
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
NameOrID: TypeAlias = str | int
Value: TypeAlias = int | float
2025-03-21 19:13:44 +08:00
MAX_ID_RANGE = 252
2025-03-22 17:33:42 +08:00
logger = logging.getLogger(__name__)
2025-03-25 18:12:26 +08:00
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
    """Return the control table registered for `model`.

    Args:
        model_ctrl_table: mapping of model name -> control table.
        model: model name to look up.

    Raises:
        KeyError: if `model` is absent from `model_ctrl_table` (or mapped to None).
    """
    if (table := model_ctrl_table.get(model)) is not None:
        return table
    raise KeyError(f"Control table for {model=} not found.")


def get_address(model_ctrl_table: dict[str, dict], model: str, data_name: str) -> tuple[int, int]:
    """Look up the `(address, n_bytes)` entry of `data_name` in `model`'s control table.

    Args:
        model_ctrl_table: mapping of model name -> control table.
        model: model name whose table is searched.
        data_name: name of the control-table entry.

    Raises:
        KeyError: if `model` has no control table, or `data_name` is not in it.
    """
    addr_bytes = get_ctrl_table(model_ctrl_table, model).get(data_name)
    if addr_bytes is not None:
        return addr_bytes
    raise KeyError(f"Address for '{data_name}' not found in {model} control table.")


def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
    """Check that `data_name` has the same address and byte size for every given model.

    Sync reads/writes use a single address and length for all motors, so mixed models are only
    supported when they agree on both.

    Args:
        model_ctrl_table: mapping of model name -> control table ({data_name: (address, n_bytes)}).
        motor_models: models to compare. An empty list has nothing to compare and passes.
        data_name: name of the control-table entry to compare.

    Raises:
        KeyError: if a model or `data_name` is missing from the control tables.
        NotImplementedError: if the models disagree on the address or on the number of bytes.
    """
    if not motor_models:
        # Nothing to compare: avoid reporting a bogus "different address" error for an empty list.
        return

    all_addr = []
    all_n_bytes = []  # not named `bytes` to avoid shadowing the builtin
    for model in motor_models:
        addr, n_bytes = get_address(model_ctrl_table, model, data_name)
        all_addr.append(addr)
        all_n_bytes.append(n_bytes)

    if len(set(all_addr)) != 1:
        raise NotImplementedError(
            f"At least two motor models use a different address for `data_name`='{data_name}' "
            f"({list(zip(motor_models, all_addr, strict=False))})."
        )

    if len(set(all_n_bytes)) != 1:
        raise NotImplementedError(
            f"At least two motor models use a different bytes representation for `data_name`='{data_name}' "
            f"({list(zip(motor_models, all_n_bytes, strict=False))})."
        )


class MotorNormMode(Enum):
    """Normalization mode attached to each `Motor`.

    NOTE(review): presumably selects how raw values are normalized by the subclasses'
    calibration (degrees, [0, 100], [-100, 100], velocity) — confirm against subclasses.
    """

    DEGREE = 0
    RANGE_0_100 = 1
    RANGE_M100_100 = 2
    VELOCITY = 3


@dataclass
class Motor:
    """Static description of one motor attached to a bus."""

    # Id of the motor on the bus (must be unique among the bus's motors).
    id: int
    # Model name; used as the key into the bus's control / model-number tables.
    model: str
    # Normalization mode for this motor's values.
    norm_mode: MotorNormMode


class JointOutOfRangeError(Exception):
    """Raised when a joint is out of range; the message is also kept on `.message`."""

    def __init__(self, message="Joint is out of range"):
        super().__init__(message)
        self.message = message


class PortHandler(Protocol):
    """Structural interface of the serial port handler used by `MotorsBus.port_handler`.

    NOTE(review): the camelCase names presumably mirror the vendor SDK API (hence the N802 noqa).
    """

    def __init__(self, port_name):
        # Attributes exposed by implementations (declared only, not assigned here).
        self.is_open: bool
        self.baudrate: int
        self.packet_start_time: float
        self.packet_timeout: float
        self.tx_time_per_byte: float
        self.is_using: bool
        self.port_name: str
        self.ser: serial.Serial

    # --- port lifecycle ---
    def openPort(self): ...

    def closePort(self): ...

    def clearPort(self): ...

    # --- configuration ---
    def setPortName(self, port_name): ...

    def getPortName(self): ...

    def setBaudRate(self, baudrate): ...

    def getBaudRate(self): ...

    # --- raw I/O ---
    def getBytesAvailable(self): ...

    def readPort(self, length): ...

    def writePort(self, packet): ...

    # --- timing / timeouts ---
    def setPacketTimeout(self, packet_length): ...

    def setPacketTimeoutMillis(self, msec): ...

    def isPacketTimeout(self): ...

    def getCurrentTime(self): ...

    def getTimeSinceStart(self): ...

    # --- low-level setup ---
    def setupPort(self, cflag_baud): ...

    def getCFlagBaud(self, baudrate): ...


class PacketHandler(Protocol):
    """Structural interface of the packet handler used by `MotorsBus.packet_handler`.

    NOTE(review): the camelCase names presumably mirror the vendor SDK API (hence the N802 noqa).
    """

    # --- result decoding ---
    def getTxRxResult(self, result): ...

    def getRxPacketError(self, error): ...

    # --- raw packets ---
    def txPacket(self, port, txpacket): ...

    def rxPacket(self, port): ...

    def txRxPacket(self, port, txpacket): ...

    # --- instructions ---
    def ping(self, port, id): ...

    def action(self, port, id): ...

    # --- reads ---
    def readTx(self, port, id, address, length): ...

    def readRx(self, port, id, length): ...

    def readTxRx(self, port, id, address, length): ...

    def read1ByteTx(self, port, id, address): ...

    def read1ByteRx(self, port, id): ...

    def read1ByteTxRx(self, port, id, address): ...

    def read2ByteTx(self, port, id, address): ...

    def read2ByteRx(self, port, id): ...

    def read2ByteTxRx(self, port, id, address): ...

    def read4ByteTx(self, port, id, address): ...

    def read4ByteRx(self, port, id): ...

    def read4ByteTxRx(self, port, id, address): ...

    # --- writes ---
    def writeTxOnly(self, port, id, address, length, data): ...

    def writeTxRx(self, port, id, address, length, data): ...

    def write1ByteTxOnly(self, port, id, address, data): ...

    def write1ByteTxRx(self, port, id, address, data): ...

    def write2ByteTxOnly(self, port, id, address, data): ...

    def write2ByteTxRx(self, port, id, address, data): ...

    def write4ByteTxOnly(self, port, id, address, data): ...

    def write4ByteTxRx(self, port, id, address, data): ...

    def regWriteTxOnly(self, port, id, address, length, data): ...

    def regWriteTxRx(self, port, id, address, length, data): ...

    # --- sync (multi-motor) transfers ---
    def syncReadTx(self, port, start_address, data_length, param, param_length): ...

    def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...


class GroupSyncRead(Protocol):
    """Structural interface of the multi-motor reader used by `MotorsBus.sync_reader`."""

    def __init__(self, port, ph, start_address, data_length):
        # Attributes exposed by implementations (declared only, not assigned here).
        self.port: str
        self.ph: PortHandler
        self.start_address: int
        self.data_length: int
        self.last_result: bool
        self.is_param_changed: bool
        self.param: list
        self.data_dict: dict

    # --- parameter (motor id) management ---
    def makeParam(self): ...

    def addParam(self, id): ...

    def removeParam(self, id): ...

    def clearParam(self): ...

    # --- transfers ---
    def txPacket(self): ...

    def rxPacket(self): ...

    def txRxPacket(self): ...

    # --- results ---
    def isAvailable(self, id, address, data_length): ...

    def getData(self, id, address, data_length): ...


class GroupSyncWrite(Protocol):
    """Structural interface of the multi-motor writer used by `MotorsBus.sync_writer`."""

    def __init__(self, port, ph, start_address, data_length):
        # Attributes exposed by implementations (declared only, not assigned here).
        self.port: str
        self.ph: PortHandler
        self.start_address: int
        self.data_length: int
        self.is_param_changed: bool
        self.param: list
        self.data_dict: dict

    # --- parameter (motor id + data) management ---
    def makeParam(self): ...

    def addParam(self, id, data): ...

    def removeParam(self, id): ...

    def changeParam(self, id, data): ...

    def clearParam(self): ...

    # --- transfer ---
    def txPacket(self): ...


class MotorsBus(abc.ABC):
    """The main LeRobot class for implementing motors buses.

    There are currently two implementations of this abstract class:
        - DynamixelMotorsBus
        - FeetechMotorsBus

    Note: This class may evolve in the future should we add support for other manufacturers SDKs.

    A MotorsBus allows to efficiently read and write to the attached motors.
    It represents several motors daisy-chained together and connected through a serial port.

    A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751")`).
    To find the port, you can run our utility script:
    ```bash
    python lerobot/scripts/find_motors_bus_port.py
    >>> Finding all available ports for the MotorsBus.
    >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
    >>> Remove the usb cable from your MotorsBus and press Enter when done.
    >>> The port of this MotorsBus is /dev/tty.usbmodem575E0031751.
    >>> Reconnect the usb cable.
    ```

    Example of usage for 1 Feetech sts3215 motor connected to the bus:
    ```python
    motors_bus = FeetechMotorsBus(
        port="/dev/tty.usbmodem575E0031751",
        motors={"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100)},
    )
    motors_bus.connect()

    position = motors_bus.sync_read("Present_Position")["gripper"]

    # Move by a few motor steps as an example
    few_steps = 30
    motors_bus.write("Goal_Position", "gripper", position + few_steps)

    # When done, properly disconnect the port using
    motors_bus.disconnect()
    ```
    """

    # Class-level tables that concrete subclasses define.
    # model name -> control table ({data_name: (address, n_bytes)})
    model_ctrl_table: dict[str, dict]
    model_resolution_table: dict[str, int]
    model_baudrate_table: dict[str, dict]
    # model name -> model number, compared against what `broadcast_ping()` reports
    model_number_table: dict[str, int]
    # data names whose values are (un)calibrated by `sync_read` / `sync_write` / `write`
    calibration_required: list[str]
    # packet timeout (ms) applied by `set_timeout()` when none is given
    default_timeout: int

    def __init__(
        self,
        port: str,
        motors: dict[str, Motor],
    ):
        """
        Args:
            port: serial port the bus is connected to.
            motors: mapping of motor name -> `Motor` description.

        Raises:
            ValueError: if two motors share the same id.
            KeyError: if a motor model has no control table.
        """
        self.port = port
        self.motors = motors
        self._validate_motors()

        # Declared but not assigned here.
        # NOTE(review): presumably set by the concrete subclasses' __init__.
        self.port_handler: PortHandler
        self.packet_handler: PacketHandler
        self.sync_reader: GroupSyncRead
        self.sync_writer: GroupSyncWrite
        self._comm_success: int
        self._no_error: int

        self.calibration = None

        # Reverse lookups built once.
        self._id_to_model_dict = {m.id: m.model for m in self.motors.values()}
        self._id_to_name_dict = {m.id: name for name, m in self.motors.items()}
        self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()}

    def __len__(self):
        return len(self.motors)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(\n"
            f" Port: '{self.port}',\n"
            f" Motors: \n{pformat(self.motors, indent=8, sort_dicts=False)},\n"
            ")"
        )

    @cached_property
    def _has_different_ctrl_tables(self) -> bool:
        """True when at least two motors' models have control tables that differ (via DeepDiff)."""
        if len(self.models) < 2:
            return False

        first_table = get_ctrl_table(self.model_ctrl_table, self.models[0])
        return any(
            DeepDiff(first_table, get_ctrl_table(self.model_ctrl_table, model)) for model in self.models[1:]
        )

    @cached_property
    def names(self) -> list[str]:
        """Motor names, in the order of `self.motors`."""
        return list(self.motors)

    @cached_property
    def models(self) -> list[str]:
        """Motor models, in the order of `self.motors`."""
        return [m.model for m in self.motors.values()]

    @cached_property
    def ids(self) -> list[int]:
        """Motor ids, in the order of `self.motors`."""
        return [m.id for m in self.motors.values()]

    def _model_nb_to_model(self, motor_nb: int) -> str:
        return self._model_nb_to_model_dict[motor_nb]

    def _id_to_model(self, motor_id: int) -> str:
        return self._id_to_model_dict[motor_id]

    def _id_to_name(self, motor_id: int) -> str:
        return self._id_to_name_dict[motor_id]

    def _get_motor_id(self, motor: NameOrID) -> int:
        """Resolve a motor name to its id; ints are returned unchanged.

        Raises:
            KeyError: if a name is not in `self.motors`.
            TypeError: if `motor` is neither a str nor an int.
        """
        if isinstance(motor, str):
            return self.motors[motor].id
        elif isinstance(motor, int):
            return motor
        else:
            raise TypeError(f"'{motor}' should be int, str.")

    def _validate_motors(self) -> None:
        """Raise ValueError on duplicated ids and KeyError if a model has no control table."""
        if len(self.ids) != len(set(self.ids)):
            raise ValueError(f"Some motors have the same id!\n{self}")

        # Ensure ctrl table available for all models
        for model in self.models:
            get_ctrl_table(self.model_ctrl_table, model)

    def _is_comm_success(self, comm: int) -> bool:
        return comm == self._comm_success

    def _is_error(self, error: int) -> bool:
        return error != self._no_error

    def _assert_motors_exist(self) -> None:
        """Check that the motors found on the bus match `self.motors` (ids and model numbers).

        Raises:
            RuntimeError: if the ids found differ from the expected ids, or if a motor reports a
                model number different from the one expected for its model.
        """
        found_models = self.broadcast_ping()
        if found_models is None:
            # `broadcast_ping` may return None (see its signature): report it as "no motor found"
            # below instead of crashing on `set(None)`.
            found_models = {}

        expected_models = {m.id: self.model_number_table[m.model] for m in self.motors.values()}
        if set(found_models) != set(self.ids):
            raise RuntimeError(
                f"{self.__class__.__name__} is supposed to have these motors: ({{id: model_nb}})"
                f"\n{pformat(expected_models, indent=4, sort_dicts=False)}\n"
                f"But it found these motors on port '{self.port}':"
                f"\n{pformat(found_models, indent=4, sort_dicts=False)}\n"
            )

        for id_, model in expected_models.items():
            if found_models[id_] != model:
                raise RuntimeError(
                    f"Motor '{self._id_to_name(id_)}' (id={id_}) is supposed to be of model_number={model} "
                    f"('{self._id_to_model(id_)}') but a model_number={found_models[id_]} "
                    "was found instead for that id."
                )

    @property
    def is_connected(self) -> bool:
        return self.port_handler.is_open

    def connect(self, assert_motors_exist: bool = True) -> None:
        """Open the serial port, optionally check the motors, then apply the default packet timeout.

        Args:
            assert_motors_exist: when True, run `_assert_motors_exist()` right after opening the port.

        Raises:
            DeviceAlreadyConnectedError: if the port is already open.
            OSError / FileNotFoundError / serial.SerialException: if the port can't be opened (an error
                hinting at `find_motors_bus_port.py` is logged first).
            RuntimeError: if the motors found don't match `self.motors`.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(
                f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
            )

        try:
            if not self.port_handler.openPort():
                raise OSError(f"Failed to open port '{self.port}'.")
            elif assert_motors_exist:
                # NOTE(review): if this check fails, the port is left open (nothing closes it here).
                self._assert_motors_exist()
        except (FileNotFoundError, OSError, serial.SerialException):
            logger.error(
                f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
                "\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
            )
            raise

        self.set_timeout()
        logger.debug(f"{self.__class__.__name__} connected.")

    @abc.abstractmethod
    def _configure_motors(self) -> None:
        pass

    def set_timeout(self, timeout_ms: int | None = None) -> None:
        """Set the packet timeout in milliseconds (`self.default_timeout` when None)."""
        timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout
        self.port_handler.setPacketTimeoutMillis(timeout_ms)

    def get_baudrate(self) -> int:
        return self.port_handler.getBaudRate()

    def set_baudrate(self, baudrate: int) -> None:
        """Change the bus baud rate if it differs from the current one.

        Raises:
            OSError: if the port handler does not report the new baud rate afterwards.
        """
        present_bus_baudrate = self.port_handler.getBaudRate()
        if present_bus_baudrate != baudrate:
            logger.info(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
            self.port_handler.setBaudRate(baudrate)

            if self.port_handler.getBaudRate() != baudrate:
                raise OSError("Failed to write bus baud rate.")

    @property
    def are_motors_configured(self) -> bool:
        """
        Only check the motor indices and not baudrate, since if the motor baudrates are incorrect, a
        ConnectionError will be raised anyway.
        """
        try:
            # TODO(aliberts): use ping instead
            # `sync_read` returns {motor_name: value}; IDs are compared raw (no calibration).
            ids_read = self.sync_read("ID", raw_values=True)
        except ConnectionError as e:
            logger.error(e)
            return False

        return all(ids_read[name] == motor.id for name, motor in self.motors.items())

    @abc.abstractmethod
    def _calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]:
        pass

    @abc.abstractmethod
    def _uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]:
        pass

    @staticmethod
    @abc.abstractmethod
    def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]:
        """
        Splits an unsigned integer into a list of bytes in little-endian order.

        This function extracts the individual bytes of an integer based on the
        specified number of bytes (`n_bytes`). The output is a list of integers,
        each representing a byte (0-255).

        **Byte order:** The function returns bytes in **little-endian format**,
        meaning the least significant byte (LSB) comes first.

        Args:
            value (int): The unsigned integer to be converted into a byte list. Must be within
                the valid range for the specified `n_bytes`.
            n_bytes (int): The number of bytes to use for conversion. Supported values:
                - 1 (for values 0 to 255)
                - 2 (for values 0 to 65,535)
                - 4 (for values 0 to 4,294,967,295)

        Raises:
            ValueError: If `value` is negative or exceeds the maximum allowed for `n_bytes`.
            NotImplementedError: If `n_bytes` is not 1, 2, or 4.

        Returns:
            list[int]: A list of integers, each representing a byte in **little-endian order**.

        Examples:
            _split_int_to_bytes(0x12, 1)        -> [18]
            _split_int_to_bytes(0x1234, 2)      -> [52, 18]            # 0x34 0x12
            _split_int_to_bytes(0x12345678, 4)  -> [120, 86, 52, 18]   # 0x78 0x56 0x34 0x12
        """
        pass

    def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None:
        """Ping a single motor and return its model number.

        Args:
            motor: name or id of the motor.
            num_retry: number of extra attempts after a failed communication.
            raise_on_error: raise instead of returning None when the ping fails.

        Returns:
            The model number returned by the ping, or None on failure when `raise_on_error` is False.

        Raises:
            ConnectionError: on a communication failure (only when `raise_on_error`).
            RuntimeError: when the reply carries an error (only when `raise_on_error`).
        """
        id_ = self._get_motor_id(motor)
        for n_try in range(1 + num_retry):
            model_number, comm, error = self.packet_handler.ping(self.port_handler, id_)
            if self._is_comm_success(comm):
                break
            logger.debug(f"ping failed for {id_=}: {n_try=} got {comm=} {error=}")

        if not self._is_comm_success(comm):
            if raise_on_error:
                # `comm` is a communication result: decode it with getTxRxResult (as `write` does).
                raise ConnectionError(self.packet_handler.getTxRxResult(comm))
            return None

        if self._is_error(error):
            if raise_on_error:
                # `error` is the packet error: decode it with getRxPacketError (as `write` does).
                raise RuntimeError(self.packet_handler.getRxPacketError(error))
            return None

        return model_number

    @abc.abstractmethod
    def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
        """Return {motor_id: model_number} for the motors answering on the bus, or None on failure."""
        pass

    @overload
    def sync_read(
        self, data_name: str, motors: None = ..., *, raw_values: bool = ..., num_retry: int = ...
    ) -> dict[str, Value]: ...
    @overload
    def sync_read(
        self,
        data_name: str,
        motors: NameOrID | list[NameOrID],
        *,
        raw_values: bool = ...,
        num_retry: int = ...,
    ) -> dict[NameOrID, Value]: ...
    def sync_read(
        self,
        data_name: str,
        motors: NameOrID | list[NameOrID] | None = None,
        *,
        raw_values: bool = False,
        num_retry: int = 0,
    ) -> dict[NameOrID, Value]:
        """Read `data_name` from several motors with one sync-read transaction.

        Args:
            data_name: control-table entry to read.
            motors: a motor name/id, a list of them, or None for every motor on the bus.
            raw_values: skip calibration even if `data_name` is in `calibration_required`.
            num_retry: number of extra attempts after a failed transaction.

        Returns:
            {key: value}, keyed exactly as the caller referred to each motor (names when `motors` is None).

        Raises:
            DeviceNotConnectedError: if the bus is not connected.
            TypeError: if `motors` has an unsupported type.
            ConnectionError: if the transaction still fails after all attempts.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(
                f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
            )

        id_key_map: dict[int, NameOrID] = {}
        if motors is None:
            id_key_map = {m.id: name for name, m in self.motors.items()}
        elif isinstance(motors, (str, int)):
            id_key_map = {self._get_motor_id(motors): motors}
        elif isinstance(motors, list):
            id_key_map = {self._get_motor_id(m): m for m in motors}
        else:
            raise TypeError(motors)

        motor_ids = list(id_key_map)

        comm, ids_values = self._sync_read(data_name, motor_ids, num_retry=num_retry)
        if not self._is_comm_success(comm):
            raise ConnectionError(
                f"Failed to sync read '{data_name}' on {motor_ids=} after {num_retry + 1} tries."
                f"{self.packet_handler.getTxRxResult(comm)}"
            )

        if not raw_values and data_name in self.calibration_required and self.calibration is not None:
            ids_values = self._calibrate_values(ids_values)

        return {id_key_map[id_]: val for id_, val in ids_values.items()}

    def _sync_read(
        self, data_name: str, motor_ids: list[int], model: str | None = None, num_retry: int = 0
    ) -> tuple[int, dict[int, int]]:
        """Run the sync-read transaction and return (last comm result, {motor_id: raw value})."""
        if self._has_different_ctrl_tables:
            models = [self._id_to_model(id_) for id_ in motor_ids]
            assert_same_address(self.model_ctrl_table, models, data_name)

        # All motors share the address (checked above), so the first motor's model is representative.
        model = self._id_to_model(next(iter(motor_ids))) if model is None else model
        addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
        self._setup_sync_reader(motor_ids, addr, n_bytes)

        # FIXME(aliberts, pkooij): We should probably not have to do this.
        # Let's try to see if we can do with better comm status handling instead.
        # self.port_handler.ser.reset_output_buffer()
        # self.port_handler.ser.reset_input_buffer()

        for n_try in range(1 + num_retry):
            comm = self.sync_reader.txRxPacket()
            if self._is_comm_success(comm):
                break
            logger.debug(f"Failed to sync read '{data_name}' ({addr=} {n_bytes=}) on {motor_ids=} ({n_try=})")
            logger.debug(self.packet_handler.getTxRxResult(comm))

        values = {id_: self.sync_reader.getData(id_, addr, n_bytes) for id_ in motor_ids}
        return comm, values

    def _setup_sync_reader(self, motor_ids: list[int], addr: int, n_bytes: int) -> None:
        """Reset the sync reader and register `motor_ids` for reading `n_bytes` at `addr`."""
        self.sync_reader.clearParam()
        self.sync_reader.start_address = addr
        self.sync_reader.data_length = n_bytes
        for id_ in motor_ids:
            self.sync_reader.addParam(id_)

    # TODO(aliberts, pkooij): Implementing something like this could get even much faster read times if need be.
    # Would have to handle the logic of checking if a packet has been sent previously though but doable.
    # This could be at the cost of increase latency between the moment the data is produced by the motors and
    # the moment it is used by a policy.
    # def _async_read(self, motor_ids: list[str], address: int, n_bytes: int):
    #     if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...:
    #         self._setup_sync_reader(motor_ids, address, n_bytes)
    #     else:
    #         self.sync_reader.rxPacket()
    #         self.sync_reader.txPacket()

    #     for id_ in motor_ids:
    #         value = self.sync_reader.getData(id_, address, n_bytes)

    def sync_write(
        self,
        data_name: str,
        values: Value | dict[NameOrID, Value],
        *,
        raw_values: bool = False,
        num_retry: int = 0,
    ) -> None:
        """Write `data_name` to several motors with one sync-write transaction.

        Args:
            data_name: control-table entry to write.
            values: a single int/float written to every motor, or {motor name/id: value}.
            raw_values: skip uncalibration even if `data_name` is in `calibration_required`.
            num_retry: number of extra attempts after a failed transaction.

        Raises:
            DeviceNotConnectedError: if the bus is not connected.
            ValueError: if `values` is neither a single value nor a dict.
            ConnectionError: if the transaction still fails after all attempts.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(
                f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
            )

        # `Value` is int | float, so a single float is broadcast to every motor just like an int.
        if isinstance(values, (int, float)):
            ids_values = {id_: values for id_ in self.ids}
        elif isinstance(values, dict):
            ids_values = {self._get_motor_id(motor): val for motor, val in values.items()}
        else:
            raise ValueError(f"'values' is expected to be a single value or a dict. Got {values}")

        if not raw_values and data_name in self.calibration_required and self.calibration is not None:
            ids_values = self._uncalibrate_values(ids_values)

        comm = self._sync_write(data_name, ids_values, num_retry=num_retry)
        if not self._is_comm_success(comm):
            raise ConnectionError(
                f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
                f"\n{self.packet_handler.getTxRxResult(comm)}"
            )

    def _sync_write(self, data_name: str, ids_values: dict[int, int], num_retry: int = 0) -> int:
        """Run the sync-write transaction and return the last comm result."""
        if self._has_different_ctrl_tables:
            models = [self._id_to_model(id_) for id_ in ids_values]
            assert_same_address(self.model_ctrl_table, models, data_name)

        model = self._id_to_model(next(iter(ids_values)))
        addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
        self._setup_sync_writer(ids_values, addr, n_bytes)

        for n_try in range(1 + num_retry):
            comm = self.sync_writer.txPacket()
            if self._is_comm_success(comm):
                break
            logger.debug(
                f"Failed to sync write '{data_name}' ({addr=} {n_bytes=}) with {ids_values=} ({n_try=})"
            )
            logger.debug(self.packet_handler.getTxRxResult(comm))

        return comm

    def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, n_bytes: int) -> None:
        """Reset the sync writer and register each (id, value) as `n_bytes` little-endian bytes at `addr`."""
        self.sync_writer.clearParam()
        self.sync_writer.start_address = addr
        self.sync_writer.data_length = n_bytes
        for id_, value in ids_values.items():
            data = self._split_int_to_bytes(value, n_bytes)
            self.sync_writer.addParam(id_, data)

    def write(
        self, data_name: str, motor: NameOrID, value: Value, *, raw_value: bool = False, num_retry: int = 0
    ) -> None:
        """Write `data_name` to a single motor and check the reply.

        Raises:
            DeviceNotConnectedError: if the bus is not connected.
            ConnectionError: if communication still fails after all attempts.
            RuntimeError: if the motor replies with an error.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(
                f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
            )

        id_ = self._get_motor_id(motor)

        if not raw_value and data_name in self.calibration_required and self.calibration is not None:
            id_value = self._uncalibrate_values({id_: value})
            value = id_value[id_]

        comm, error = self._write(data_name, id_, value, num_retry=num_retry)
        if not self._is_comm_success(comm):
            raise ConnectionError(
                f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
                f"\n{self.packet_handler.getTxRxResult(comm)}"
            )
        elif self._is_error(error):
            raise RuntimeError(
                f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
                f"\n{self.packet_handler.getRxPacketError(error)}"
            )

    def _write(self, data_name: str, motor_id: int, value: int, num_retry: int = 0) -> tuple[int, int]:
        """Write to one motor with retries and return the last (comm result, packet error)."""
        model = self._id_to_model(motor_id)
        addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
        data = self._split_int_to_bytes(value, n_bytes)

        for n_try in range(1 + num_retry):
            comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data)
            if self._is_comm_success(comm):
                break
            logger.debug(
                f"Failed to write '{data_name}' ({addr=} {n_bytes=}) on {motor_id=} with '{value}' ({n_try=})"
            )
            logger.debug(self.packet_handler.getTxRxResult(comm))

        return comm, error

    def disconnect(self) -> None:
        """Close the serial port.

        Raises:
            DeviceNotConnectedError: if the bus is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(
                f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
            )

        self.port_handler.closePort()
        logger.debug(f"{self.__class__.__name__} disconnected.")