diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 10f9b56f..2a1cf2b1 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -1,12 +1,14 @@ import abc import random +import threading +import time from typing import Callable import dynamixel_sdk as dxl import serial -from mock_serial import MockSerial +from mock_serial.mock_serial import MockSerial, Stub -from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE +from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE, DynamixelMotorsBus # https://emanual.robotis.com/docs/en/dxl/crc/ DXL_CRC_TABLE = [ @@ -245,6 +247,53 @@ class MockInstructionPacket(MockDynamixelPacketv2): length = len(dxl_ids) + 7 return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read") + @classmethod + def sync_write( + cls, + ids_values: dict[int], + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Sync_Write" broadcast instruction. + (from https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-write-0x83) + + The parameters for Sync_Write (Protocol 2.0) are: + param[0] = start_address L + param[1] = start_address H + param[2] = data_length L + param[3] = data_length H + param[5] = [1st motor] ID + param[5+1] = [1st motor] 1st Byte + param[5+2] = [1st motor] 2nd Byte + ... + param[5+X] = [1st motor] X-th Byte + param[6] = [2nd motor] ID + param[6+1] = [2nd motor] 1st Byte + param[6+2] = [2nd motor] 2nd Byte + ... + param[6+X] = [2nd motor] X-th Byte + + And 'length' = ((number_of_params * 1 + data_length) + 7), where: + +1 is for instruction byte, + +2 is for the address bytes, + +2 is for the length bytes, + +2 is for the CRC at the end. + """ + data = [] + for idx, value in ids_values.items(): + split_value = DynamixelMotorsBus.split_int_bytes(value, data_length) + data += [idx, *split_value] + params = [ + dxl.DXL_LOBYTE(start_address), + dxl.DXL_HIBYTE(start_address), + dxl.DXL_LOBYTE(data_length), + dxl.DXL_HIBYTE(data_length), + *data, + ] + length = len(ids_values) * (1 + data_length) + 7 + return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write") + class MockStatusPacket(MockDynamixelPacketv2): """ @@ -337,13 +386,38 @@ class MockPortHandler(dxl.PortHandler): return True +class WaitableStub(Stub): + """ + In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time + to read, match, and call the stub. In these situations, the test can fail randomly. + + Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._event = threading.Event() + + def call(self): + self._event.set() + return super().call() + + def wait_called(self, timeout: float = 1.0): + return self._event.wait(timeout) + + def wait_calls(self, min_calls: int = 1, timeout: float = 1.0): + start = time.perf_counter() + while time.perf_counter() - start < timeout: + if self.calls >= min_calls: + return self.calls + time.sleep(0.005) + raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.") + + class MockMotors(MockSerial): """ This class will simulate physical motors by responding with valid status packets upon receiving some instruction packets. It is meant to test MotorsBus classes. - - 'data_name' supported: - - Present_Position """ ctrl_table = X_SERIES_CONTROL_TABLE @@ -352,6 +426,15 @@ class MockMotors(MockSerial): super().__init__() self.open() + @property + def stubs(self) -> dict[str, WaitableStub]: + return super().stubs + + def stub(self, *, name=None, **kwargs): + new_stub = WaitableStub(**kwargs) + self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub + return new_stub + def build_broadcast_ping_stub( self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0 ) -> str: @@ -388,6 +471,10 @@ class MockMotors(MockSerial): def build_sync_read_stub( self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 ) -> str: + """ + 'data_name' supported: + - Present_Position + """ address, length = self.ctrl_table[data_name] sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) if data_name != "Present_Position": @@ -406,6 +493,22 @@ class MockMotors(MockSerial): ) return stub_name + def build_sync_write_stub( + self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 + ) -> str: + address, length = self.ctrl_table[data_name] + sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length) + # if data_name != "Goal_Position": + # raise NotImplementedError + + stub_name = f"Sync_Write_{data_name}_" + "_".join([str(idx) for idx in ids_values]) + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=self._build_send_fn(b"", num_invalid_try), + ) + return stub_name + @staticmethod def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: def send_fn(_call_count: int) -> bytes: diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index 412a4a36..6d08ba00 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -233,3 +233,54 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors): expected_calls = min(1 + num_retry, 1 + num_invalid_try) assert mock_motors.stubs[stub_name].calls == expected_calls + + +@pytest.mark.parametrize( + "motors", + [ + [1, 2, 3], + ["dummy_1", "dummy_2", "dummy_3"], + [1, "dummy_2", 3], + ], + ids=["by ids", "by names", "mixed"], +) +def test_write_all_motors(motors, dummy_motors): + mock_motors = MockMotors() + goal_positions = { + 1: 1337, + 2: 42, + 3: 4016, + } + stub_name = mock_motors.build_sync_write_stub("Goal_Position", goal_positions) + motors_bus = DynamixelMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect() + + values = dict(zip(motors, goal_positions.values(), strict=True)) + motors_bus.write("Goal_Position", values) + + assert mock_motors.stubs[stub_name].wait_called() + + +@pytest.mark.parametrize( + "data_name, value", + [ + ["Torque_Enable", 0], + ["Torque_Enable", 1], + ], +) +def test_write_all_motors_single_value(data_name, value, dummy_motors): + mock_motors = MockMotors() + values = {m.id: value for m in dummy_motors.values()} + stub_name = mock_motors.build_sync_write_stub(data_name, values) + motors_bus = DynamixelMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect() + + motors_bus.write(data_name, value) + + assert mock_motors.stubs[stub_name].wait_called()