From 2c1bb766ffd3836e7ec5d87075952db023d68ff4 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 20 Mar 2025 09:40:58 +0100 Subject: [PATCH] Refactor MockMotors, add return values --- tests/mocks/mock_dynamixel.py | 81 ++++++++++++++++++++-------------- tests/motors/test_dynamixel.py | 63 ++++++++++++++++---------- 2 files changed, 86 insertions(+), 58 deletions(-) diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 5549cf18..cccc3c20 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -288,59 +288,72 @@ class MockPortHandler(dxl.PortHandler): class MockMotors(MockSerial): + """ + This class will simulate physical motors by responding with valid status packets upon receiving some + instruction packets. It is meant to test MotorsBus classes. + + 'data_name' supported: + - Present_Position + """ + ctrl_table = X_SERIES_CONTROL_TABLE - def __init__(self, dlx_ids: list[int], default_stubs: bool = True): + def __init__(self, dlx_ids: list[int]): super().__init__() self._ids = dlx_ids self.open() - if default_stubs: - self._create_stubs("Present_Position") - - def _create_stubs(self, data_name: str): + def build_single_motor_stubs( + self, data_name: str, return_value: int | None = None, num_invalid_try: int | None = None + ) -> None: address, length = self.ctrl_table[data_name] - - # sync read all motors - sync_read_request_all = MockInstructionPacket.sync_read(self._ids, address, length) - sync_read_response_all = self._create_present_pos_send_fn(self._ids, data_name) - self.stub( - name=f"SyncRead_{data_name}_all", - receive_bytes=sync_read_request_all, - send_fn=sync_read_response_all, - ) - - # sync read single motors for idx in self._ids: - sync_read_request_single = MockInstructionPacket.sync_read([idx], address, length) - sync_read_response_single = self._create_present_pos_send_fn([idx], data_name) + if data_name == "Present_Position": + sync_read_request_single = MockInstructionPacket.sync_read([idx], address, length) + sync_read_response_single = self._build_present_pos_send_fn( + [idx], [return_value], num_invalid_try + ) + else: + raise NotImplementedError # TODO(aliberts): add ping? + self.stub( name=f"SyncRead_{data_name}_{idx}", receive_bytes=sync_read_request_single, send_fn=sync_read_response_single, ) - def _create_present_pos_send_fn( - self, dxl_ids: list[int], data_name: str, num_invalid_try: int | None = None + def build_all_motors_stub( + self, data_name: str, return_values: list[int] | None = None, num_invalid_try: int | None = None + ) -> None: + address, length = self.ctrl_table[data_name] + if data_name == "Present_Position": + sync_read_request_all = MockInstructionPacket.sync_read(self._ids, address, length) + sync_read_response_all = self._build_present_pos_send_fn( + self._ids, return_values, num_invalid_try + ) + else: + raise NotImplementedError # TODO(aliberts): add ping? + + self.stub( + name=f"SyncRead_{data_name}_all", + receive_bytes=sync_read_request_all, + send_fn=sync_read_response_all, + ) + + def _build_present_pos_send_fn( + self, dxl_ids: list[int], return_pos: list[int] | None = None, num_invalid_try: int | None = None ) -> Callable[[int], bytes]: - # if data_name == "Present_Position": - # packet_generator = MockStatusPacket.present_position - # else: - # # TODO(aliberts): add "Goal_Position" - # raise NotImplementedError + return_pos = [None for _ in dxl_ids] if return_pos is None else return_pos + assert len(return_pos) == len(dxl_ids) def send_fn(_call_count: int) -> bytes: if num_invalid_try is not None and num_invalid_try >= _call_count: - return bytes(0) - - first_packet = MockStatusPacket.present_position(next(iter(dxl_ids))) - if len(dxl_ids) == 1: - return first_packet - - packets = first_packet - for idx in dxl_ids: - packets += MockStatusPacket.present_position(dxl_id=idx) + return b"" + packets = b"".join( + MockStatusPacket.present_position(idx, pos) + for idx, pos in zip(dxl_ids, return_pos, strict=True) + ) return packets return send_fn diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index 77f6bd73..afdeb7f9 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -5,7 +5,7 @@ import dynamixel_sdk as dxl import pytest from lerobot.common.motors.dynamixel.dynamixel import DynamixelMotorsBus -from tests.mocks.mock_dynamixel import MockInstructionPacket, MockMotors, MockPortHandler +from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler @pytest.fixture(autouse=True) @@ -68,11 +68,13 @@ def test_abc_implementation(): None, [1, 2, 3], ["dummy_1", "dummy_2", "dummy_3"], - [1, "dummy_2", 3], # Mixed + [1, "dummy_2", 3], ], ) def test_read_all_motors(motors): mock_motors = MockMotors([1, 2, 3]) + positions = [1337, 42, 4016] + mock_motors.build_all_motors_stub("Present_Position", return_values=positions) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors={ @@ -86,13 +88,22 @@ def test_read_all_motors(motors): pos_dict = motors_bus.read("Present_Position", motors=motors) assert mock_motors.stubs["SyncRead_Present_Position_all"].called - assert len(pos_dict) == 3 + assert all(returned_pos == pos for returned_pos, pos in zip(pos_dict.values(), positions, strict=True)) + assert set(pos_dict) == {"dummy_1", "dummy_2", "dummy_3"} assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) -@pytest.mark.parametrize("idx", [1, 2, 3]) -def test_read_single_motor_name(idx): +@pytest.mark.parametrize( + "idx, pos", + [ + [1, 1337], + [2, 42], + [3, 4016], + ], +) +def test_read_single_motor_by_name(idx, pos): mock_motors = MockMotors([1, 2, 3]) + mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors={ @@ -106,13 +117,21 @@ def test_read_single_motor_name(idx): pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}") assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called - assert len(pos_dict) == 1 + assert pos_dict == {f"dummy_{idx}": pos} assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) -@pytest.mark.parametrize("idx", [1, 2, 3]) -def test_read_single_motor_id(idx): +@pytest.mark.parametrize( + "idx, pos", + [ + [1, 1337], + [2, 42], + [3, 4016], + ], +) +def test_read_single_motor_by_id(idx, pos): mock_motors = MockMotors([1, 2, 3]) + mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors={ @@ -126,28 +145,24 @@ def test_read_single_motor_id(idx): pos_dict = motors_bus.read("Present_Position", idx) assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called - assert len(pos_dict) == 1 + assert pos_dict == {f"dummy_{idx}": pos} assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) @pytest.mark.parametrize( - "num_retry, num_invalid_try", + "num_retry, num_invalid_try, pos", [ - [1, 2], - [2, 3], - [3, 2], - [2, 1], + [1, 2, 1337], + [2, 3, 42], + [3, 2, 4016], + [2, 1, 999], ], ) -def test_read_num_retry(num_retry, num_invalid_try): - mock_motors = MockMotors([1, 2, 3], default_stubs=None) - address, length = mock_motors.ctrl_table["Present_Position"] - receive_bytes = MockInstructionPacket.sync_read([1], address, length) - send_fn = mock_motors._create_present_pos_send_fn( - [1], "Present_Position", num_invalid_try=num_invalid_try +def test_read_num_retry(num_retry, num_invalid_try, pos): + mock_motors = MockMotors([1, 2, 3]) + mock_motors.build_single_motor_stubs( + "Present_Position", return_value=pos, num_invalid_try=num_invalid_try ) - mock_motors.stub(name="num_retry", receive_bytes=receive_bytes, send_fn=send_fn) - motors_bus = DynamixelMotorsBus( port=mock_motors.port, motors={ @@ -160,10 +175,10 @@ def test_read_num_retry(num_retry, num_invalid_try): if num_retry >= num_invalid_try: pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry) - assert len(pos_dict) == 1 + assert pos_dict == {"dummy_1": pos} assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) else: with pytest.raises(ConnectionError): _ = motors_bus.read("Present_Position", 1, num_retry=num_retry) - assert mock_motors.stubs["num_retry"].calls == num_retry + assert mock_motors.stubs["SyncRead_Present_Position_1"].calls == num_retry