From 1d3e1cbdbdba33e370ae66cd15bb33de0f185afa Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sat, 22 Mar 2025 17:02:01 +0100 Subject: [PATCH] Add feetech write tests --- tests/mocks/mock_feetech.py | 105 ++++++++++++++++++++++++++++++++--- tests/motors/test_feetech.py | 51 +++++++++++++++++ 2 files changed, 149 insertions(+), 7 deletions(-) diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 7ff994ee..2a10e789 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -1,12 +1,14 @@ import abc import random +import threading +import time from typing import Callable import scservo_sdk as scs import serial -from mock_serial import MockSerial +from mock_serial.mock_serial import MockSerial, Stub -from lerobot.common.motors.feetech.tables import SCS_SERIES_CONTROL_TABLE +from lerobot.common.motors.feetech import SCS_SERIES_CONTROL_TABLE, FeetechMotorsBus # https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf INSTRUCTION_TYPES = { @@ -100,12 +102,12 @@ class MockInstructionPacket(MockFeetechPacket): """ Builds a "Sync_Read" broadcast instruction. - The parameters for Sync Read (Protocol 2.0) are: + The parameters for Sync Read are: param[0] = start_address param[1] = data_length param[2+] = motor IDs to read from - And 'length' = (number_of_params + 7), where: + And 'length' = (number_of_params + 4), where: +1 is for instruction byte, +1 is for the address byte, +1 is for the length bytes, @@ -115,6 +117,44 @@ class MockInstructionPacket(MockFeetechPacket): length = len(scs_ids) + 4 return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read") + @classmethod + def sync_write( + cls, + ids_values: dict[int], + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Sync_Write" broadcast instruction. + + The parameters for Sync_Write are: + param[0] = start_address + param[1] = data_length + param[2] = [1st motor] ID + param[2+1] = [1st motor] 1st Byte + param[2+2] = [1st motor] 2nd Byte + ... + param[5+X] = [1st motor] X-th Byte + param[6] = [2nd motor] ID + param[6+1] = [2nd motor] 1st Byte + param[6+2] = [2nd motor] 2nd Byte + ... + param[6+X] = [2nd motor] X-th Byte + + And 'length' = ((number_of_params * 1 + data_length) + 4), where: + +1 is for instruction byte, + +1 is for the address byte, + +1 is for the length bytes, + +1 is for the checksum at the end. + """ + data = [] + for idx, value in ids_values.items(): + split_value = FeetechMotorsBus.split_int_bytes(value, data_length) + data += [idx, *split_value] + params = [start_address, data_length, *data] + length = len(ids_values) * (1 + data_length) + 4 + return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write") + class MockStatusPacket(MockFeetechPacket): """ @@ -205,13 +245,38 @@ class MockPortHandler(scs.PortHandler): return True +class WaitableStub(Stub): + """ + In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time + to read, match, and call the stub. In these situations, the test can fail randomly. + + Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._event = threading.Event() + + def call(self): + self._event.set() + return super().call() + + def wait_called(self, timeout: float = 1.0): + return self._event.wait(timeout) + + def wait_calls(self, min_calls: int = 1, timeout: float = 1.0): + start = time.perf_counter() + while time.perf_counter() - start < timeout: + if self.calls >= min_calls: + return self.calls + time.sleep(0.005) + raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.") + + class MockMotors(MockSerial): """ This class will simulate physical motors by responding with valid status packets upon receiving some instruction packets. It is meant to test MotorsBus classes. - - 'data_name' supported: - - Present_Position """ ctrl_table = SCS_SERIES_CONTROL_TABLE @@ -220,6 +285,15 @@ class MockMotors(MockSerial): super().__init__() self.open() + @property + def stubs(self) -> dict[str, WaitableStub]: + return super().stubs + + def stub(self, *, name=None, **kwargs): + new_stub = WaitableStub(**kwargs) + self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub + return new_stub + def build_broadcast_ping_stub( self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0 ) -> str: @@ -256,6 +330,10 @@ class MockMotors(MockSerial): def build_sync_read_stub( self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 ) -> str: + """ + 'data_name' supported: + - Present_Position + """ address, length = self.ctrl_table[data_name] sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) if data_name != "Present_Position": @@ -274,6 +352,19 @@ class MockMotors(MockSerial): ) return stub_name + def build_sync_write_stub( + self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 + ) -> str: + address, length = self.ctrl_table[data_name] + sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length) + stub_name = f"Sync_Write_{data_name}_" + "_".join([str(idx) for idx in ids_values]) + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=self._build_send_fn(b"", num_invalid_try), + ) + return stub_name + @staticmethod def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: def send_fn(_call_count: int) -> bytes: diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index d4a4397d..698e0adb 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -235,3 +235,54 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors): expected_calls = min(1 + num_retry, 1 + num_invalid_try) assert mock_motors.stubs[stub_name].calls == expected_calls + + +@pytest.mark.parametrize( + "motors", + [ + [1, 2, 3], + ["dummy_1", "dummy_2", "dummy_3"], + [1, "dummy_2", 3], + ], + ids=["by ids", "by names", "mixed"], +) +def test_write_all_motors(motors, dummy_motors): + mock_motors = MockMotors() + goal_positions = { + 1: 1337, + 2: 42, + 3: 4016, + } + stub_name = mock_motors.build_sync_write_stub("Goal_Position", goal_positions) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect() + + values = dict(zip(motors, goal_positions.values(), strict=True)) + motors_bus.write("Goal_Position", values) + + assert mock_motors.stubs[stub_name].wait_called() + + +@pytest.mark.parametrize( + "data_name, value", + [ + ["Torque_Enable", 0], + ["Torque_Enable", 1], + ], +) +def test_write_all_motors_single_value(data_name, value, dummy_motors): + mock_motors = MockMotors() + values = {m.id: value for m in dummy_motors.values()} + stub_name = mock_motors.build_sync_write_stub(data_name, values) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect() + + motors_bus.write(data_name, value) + + assert mock_motors.stubs[stub_name].wait_called()