From 7c8ab8e2d6597a7d111974b1fe5befe0b6dbbb63 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 24 Mar 2025 20:46:36 +0100 Subject: [PATCH] Implement feetech broadcast ping --- lerobot/common/motors/feetech/feetech.py | 105 +++++++++++++++++++++-- lerobot/common/motors/feetech/tables.py | 8 +- tests/mocks/mock_feetech.py | 9 +- tests/motors/test_feetech.py | 30 +++---- 4 files changed, 120 insertions(+), 32 deletions(-) diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 78c30045..849136a9 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from copy import deepcopy from enum import Enum +from pprint import pformat from ..motors_bus import Motor, MotorsBus from .tables import ( CALIBRATION_REQUIRED, MODEL_BAUDRATE_TABLE, MODEL_CONTROL_TABLE, + MODEL_NUMBER, MODEL_RESOLUTION, ) @@ -27,7 +30,7 @@ PROTOCOL_VERSION = 0 BAUDRATE = 1_000_000 DEFAULT_TIMEOUT_MS = 1000 -MAX_ID_RANGE = 252 +logger = logging.getLogger(__name__) class OperatingMode(Enum): @@ -53,6 +56,7 @@ class FeetechMotorsBus(MotorsBus): model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) model_resolution_table = deepcopy(MODEL_RESOLUTION) model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE) + model_number_table = deepcopy(MODEL_NUMBER) calibration_required = deepcopy(CALIBRATION_REQUIRED) default_timeout = DEFAULT_TIMEOUT_MS @@ -115,8 +119,97 @@ class FeetechMotorsBus(MotorsBus): ] return data - def broadcast_ping( - self, num_retry: int = 0, raise_on_error: bool = False - ) -> dict[int, list[int, int]] | None: - # TODO - raise NotImplementedError + def _broadcast_ping(self) -> tuple[dict[int, int], int]: + import scservo_sdk as scs + + data_list = {} + + status_length = 6 + + rx_length = 0 + wait_length = status_length * scs.MAX_ID + + txpacket = [0] * 6 + + tx_time_per_byte = (1000.0 / self.port_handler.getBaudRate()) * 10.0 + + txpacket[scs.PKT_ID] = scs.BROADCAST_ID + txpacket[scs.PKT_LENGTH] = 2 + txpacket[scs.PKT_INSTRUCTION] = scs.INST_PING + + result = self.packet_handler.txPacket(self.port_handler, txpacket) + if result != scs.COMM_SUCCESS: + self.port_handler.is_using = False + return data_list, result + + # set rx timeout + self.port_handler.setPacketTimeoutMillis((wait_length * tx_time_per_byte) + (3.0 * scs.MAX_ID) + 16.0) + + rxpacket = [] + while True: + rxpacket += self.port_handler.readPort(wait_length - rx_length) + rx_length = len(rxpacket) + + if self.port_handler.isPacketTimeout(): # or rx_length >= wait_length + break + + self.port_handler.is_using = False + + if rx_length == 0: + return data_list, scs.COMM_RX_TIMEOUT + + while True: + if rx_length < status_length: + return data_list, scs.COMM_RX_CORRUPT + + # find packet header + for idx in range(0, (rx_length - 1)): + if (rxpacket[idx] == 0xFF) and (rxpacket[idx + 1] == 0xFF): + break + + if idx == 0: # found at the beginning of the packet + # calculate checksum + checksum = 0 + for idx in range(2, status_length - 1): # except header & checksum + checksum += rxpacket[idx] + + checksum = scs.SCS_LOBYTE(~checksum) + if rxpacket[status_length - 1] == checksum: + result = scs.COMM_SUCCESS + data_list[rxpacket[scs.PKT_ID]] = rxpacket[scs.PKT_ERROR] + + del rxpacket[0:status_length] + rx_length = rx_length - status_length + + if rx_length == 0: + return data_list, result + else: + result = scs.COMM_RX_CORRUPT + # remove header (0xFF 0xFF) + del rxpacket[0:2] + rx_length = rx_length - 2 + else: + # remove unnecessary packets + del rxpacket[0:idx] + rx_length = rx_length - idx + + def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, str] | None: + for n_try in range(1 + num_retry): + ids_status, comm = self._broadcast_ping() + if self._is_comm_success(comm): + break + logger.debug(f"Broadcast failed on port '{self.port}' ({n_try=})") + logger.debug(self.packet_handler.getRxPacketError(comm)) + + if not self._is_comm_success(comm): + if raise_on_error: + raise ConnectionError + + return ids_status if ids_status else None + + ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} + if ids_errors: + display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()} + logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}") + model_numbers = self.sync_read("Model_Number", list(ids_status), num_retry) + return {id_: self._model_nb_to_model(model_nb) for id_, model_nb in model_numbers.items()} diff --git a/lerobot/common/motors/feetech/tables.py b/lerobot/common/motors/feetech/tables.py index 5de95b61..7e7e7e9d 100644 --- a/lerobot/common/motors/feetech/tables.py +++ b/lerobot/common/motors/feetech/tables.py @@ -2,7 +2,8 @@ # https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true # data_name: (address, size_byte) SCS_SERIES_CONTROL_TABLE = { - "Model": (3, 2), + "Firmware_Version": (0, 2), + "Model_Number": (3, 2), "ID": (5, 1), "Baud_Rate": (6, 1), "Return_Delay": (7, 1), @@ -72,6 +73,11 @@ MODEL_RESOLUTION = { "sts3215": 4096, } +# {model: model_number} +MODEL_NUMBER = { + "sts3215": 777, +} + MODEL_BAUDRATE_TABLE = { "scs_series": SCS_SERIES_BAUDRATE_TABLE, "sts3215": SCS_SERIES_BAUDRATE_TABLE, diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index b1478a42..f938522e 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -295,16 +295,13 @@ class MockMotors(MockSerial): return new_stub def build_broadcast_ping_stub( - self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0 + self, ids_models: dict[int, list[int]] | None = None, num_invalid_try: int = 0 ) -> str: ping_request = MockInstructionPacket.ping(scs.BROADCAST_ID) - return_packets = b"".join( - MockStatusPacket.ping(idx, model, firm_ver) - for idx, (model, firm_ver) in ids_models_firmwares.items() - ) + return_packets = b"".join(MockStatusPacket.ping(idx, model) for idx, model in ids_models.items()) ping_response = self._build_send_fn(return_packets, num_invalid_try) - stub_name = "Ping_" + "_".join([str(idx) for idx in ids_models_firmwares]) + stub_name = "Ping_" + "_".join([str(idx) for idx in ids_models]) self.stub( name=stub_name, receive_bytes=ping_request, diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 1b50d045..994e6944 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -6,7 +6,7 @@ import pytest import scservo_sdk as scs from lerobot.common.motors import CalibrationMode, Motor -from lerobot.common.motors.feetech import FeetechMotorsBus +from lerobot.common.motors.feetech import MODEL_NUMBER, FeetechMotorsBus from tests.mocks.mock_feetech import MockMotors, MockPortHandler @@ -93,15 +93,10 @@ def test_abc_implementation(dummy_motors): @pytest.mark.skip("TODO") -@pytest.mark.parametrize( - "idx, model_nb", - [ - (1, 1190), - (2, 1200), - (3, 1120), - ], -) -def test_ping(idx, model_nb, mock_motors, dummy_motors): +@pytest.mark.parametrize("idx", [1, 2, 3]) +def test_ping(idx, mock_motors, dummy_motors): + expected_model = dummy_motors[f"dummy_{idx}"].model + model_nb = MODEL_NUMBER[expected_model] stub_name = mock_motors.build_ping_stub(idx, model_nb) motors_bus = FeetechMotorsBus( port=mock_motors.port, @@ -111,27 +106,24 @@ def test_ping(idx, model_nb, mock_motors, dummy_motors): ping_model_nb = motors_bus.ping(idx) - assert ping_model_nb == model_nb + assert ping_model_nb == expected_model assert mock_motors.stubs[stub_name].called @pytest.mark.skip("TODO") def test_broadcast_ping(mock_motors, dummy_motors): - expected_pings = { - 1: [1060, 50], - 2: [1120, 30], - 3: [1190, 10], - } - stub_name = mock_motors.build_broadcast_ping_stub(expected_pings) + expected_models = {m.id: m.model for m in dummy_motors.values()} + model_nbs = {id_: MODEL_NUMBER[model] for id_, model in expected_models.items()} + stub_name = mock_motors.build_broadcast_ping_stub(model_nbs) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect() - ping_list = motors_bus.broadcast_ping() + ping_model_nbs = motors_bus.broadcast_ping() - assert ping_list == expected_pings + assert ping_model_nbs == expected_models assert mock_motors.stubs[stub_name].called