diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 431707c5..e0a38042 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -77,6 +77,19 @@ class MockInstructionPacket(MockFeetechPacket): 0x00, # placeholder for checksum ] # fmt: skip + @classmethod + def ping( + cls, + scs_id: int, + ) -> bytes: + """ + Builds a "Ping" broadcast instruction. + + No parameters required. + """ + params, length = [], 2 + return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Ping") + @classmethod def sync_read( cls, @@ -128,6 +141,25 @@ class MockStatusPacket(MockFeetechPacket): 0x00, # placeholder for checksum ] # fmt: skip + @classmethod + def ping(cls, scs_id: int, model_nb: int = 1190, firm_ver: int = 50) -> bytes: + """Builds a 'Ping' status packet. + + Args: + scs_id (int): ID of the servo responding. + model_nb (int, optional): Desired 'model number' to be returned in the packet. Defaults to 1190 + which corresponds to a XL330-M077-T. + firm_ver (int, optional): Desired 'firmware version' to be returned in the packet. + Defaults to 50. + + Returns: + bytes: The raw 'Ping' status packet ready to be sent through serial. + """ + # raise NotImplementedError + params = [scs.SCS_LOBYTE(model_nb), scs.SCS_HIBYTE(model_nb), firm_ver] + length = 2 + return cls.build(scs_id, params=params, length=length) + @classmethod def present_position(cls, scs_id: int, pos: int | None = None, min_max_range: tuple = (0, 4095)) -> bytes: """Builds a 'Present_Position' status packet. @@ -184,62 +216,69 @@ class MockMotors(MockSerial): ctrl_table = SCS_SERIES_CONTROL_TABLE - def __init__(self, scs_ids: list[int]): + def __init__(self): super().__init__() - self._ids = scs_ids self.open() - def build_single_motor_stubs( - self, data_name: str, return_value: int | None = None, num_invalid_try: int | None = None - ) -> None: - address, length = self.ctrl_table[data_name] - for idx in self._ids: - if data_name == "Present_Position": - sync_read_request_single = MockInstructionPacket.sync_read([idx], address, length) - sync_read_response_single = self._build_present_pos_send_fn( - [idx], [return_value], num_invalid_try - ) - else: - raise NotImplementedError # TODO(aliberts): add ping? - - self.stub( - name=f"SyncRead_{data_name}_{idx}", - receive_bytes=sync_read_request_single, - send_fn=sync_read_response_single, - ) - - def build_all_motors_stub( - self, data_name: str, return_values: list[int] | None = None, num_invalid_try: int | None = None - ) -> None: - address, length = self.ctrl_table[data_name] - if data_name == "Present_Position": - sync_read_request_all = MockInstructionPacket.sync_read(self._ids, address, length) - sync_read_response_all = self._build_present_pos_send_fn( - self._ids, return_values, num_invalid_try - ) - else: - raise NotImplementedError # TODO(aliberts): add ping? - - self.stub( - name=f"SyncRead_{data_name}_all", - receive_bytes=sync_read_request_all, - send_fn=sync_read_response_all, + def build_broadcast_ping_stub( + self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0 + ) -> str: + ping_request = MockInstructionPacket.ping(scs.BROADCAST_ID) + return_packets = b"".join( + MockStatusPacket.ping(idx, model, firm_ver) + for idx, (model, firm_ver) in ids_models_firmwares.items() ) + ping_response = self._build_send_fn(return_packets, num_invalid_try) - def _build_present_pos_send_fn( - self, scs_ids: list[int], return_pos: list[int] | None = None, num_invalid_try: int | None = None - ) -> Callable[[int], bytes]: - return_pos = [None for _ in scs_ids] if return_pos is None else return_pos - assert len(return_pos) == len(scs_ids) + stub_name = "Ping_" + "_".join([str(idx) for idx in ids_models_firmwares]) + self.stub( + name=stub_name, + receive_bytes=ping_request, + send_fn=ping_response, + ) + return stub_name + def build_ping_stub( + self, scs_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0 + ) -> str: + ping_request = MockInstructionPacket.ping(scs_id) + return_packet = MockStatusPacket.ping(scs_id, model_nb, firm_ver) + ping_response = self._build_send_fn(return_packet, num_invalid_try) + + stub_name = f"Ping_{scs_id}" + self.stub( + name=stub_name, + receive_bytes=ping_request, + send_fn=ping_response, + ) + return stub_name + + def build_sync_read_stub( + self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 + ) -> str: + address, length = self.ctrl_table[data_name] + sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) + if data_name != "Present_Position": + raise NotImplementedError + + return_packets = b"".join( + MockStatusPacket.present_position(idx, pos) for idx, pos in ids_values.items() + ) + sync_read_response = self._build_send_fn(return_packets, num_invalid_try) + + stub_name = f"Sync_Read_{data_name}_" + "_".join([str(idx) for idx in ids_values]) + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=sync_read_response, + ) + return stub_name + + @staticmethod + def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: def send_fn(_call_count: int) -> bytes: - if num_invalid_try is not None and num_invalid_try >= _call_count: + if num_invalid_try >= _call_count: return b"" - - packets = b"".join( - MockStatusPacket.present_position(idx, pos) - for idx, pos in zip(scs_ids, return_pos, strict=True) - ) - return packets + return packet return send_fn diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 42292fb0..d4a4397d 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest import scservo_sdk as scs +from lerobot.common.motors import Motor from lerobot.common.motors.feetech import FeetechMotorsBus from tests.mocks.mock_feetech import MockMotors, MockPortHandler @@ -17,6 +18,15 @@ def patch_port_handler(): yield +@pytest.fixture +def dummy_motors() -> dict[str, Motor]: + return { + "dummy_1": Motor(id=1, model="sts3215"), + "dummy_2": Motor(id=2, model="sts3215"), + "dummy_3": Motor(id=3, model="sts3215"), + } + + @pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}") def test_autouse_patch(): """Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler.""" @@ -68,9 +78,52 @@ def test_split_int_bytes_large_number(): FeetechMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF -def test_abc_implementation(): +def test_abc_implementation(dummy_motors): """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" - FeetechMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "sts3215")}) + FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors) + + +@pytest.mark.skip("TODO") +@pytest.mark.parametrize( + "idx, model_nb", + [ + [1, 1190], + [2, 1200], + [3, 1120], + ], +) +def test_ping(idx, model_nb, dummy_motors): + mock_motors = MockMotors() + mock_motors.build_ping_stub(idx, model_nb) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect() + + ping_model_nb = motors_bus.ping(idx) + + assert ping_model_nb == model_nb + + +@pytest.mark.skip("TODO") +def test_broadcast_ping(dummy_motors): + expected_pings = { + 1: [1060, 50], + 2: [1120, 30], + 3: [1190, 10], + } + mock_motors = MockMotors() + mock_motors.build_broadcast_ping_stub(expected_pings) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect() + + ping_list = motors_bus.broadcast_ping() + + assert ping_list == expected_pings @pytest.mark.parametrize( @@ -83,26 +136,25 @@ def test_abc_implementation(): ], ids=["None", "by ids", "by names", "mixed"], ) -def test_read_all_motors(motors): - mock_motors = MockMotors([1, 2, 3]) - positions = [1337, 42, 4016] - mock_motors.build_all_motors_stub("Present_Position", return_values=positions) +def test_read_all_motors(motors, dummy_motors): + mock_motors = MockMotors() + expected_positions = { + 1: 1337, + 2: 42, + 3: 4016, + } + stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_positions) motors_bus = FeetechMotorsBus( port=mock_motors.port, - motors={ - "dummy_1": (1, "sts3215"), - "dummy_2": (2, "sts3215"), - "dummy_3": (3, "sts3215"), - }, + motors=dummy_motors, ) motors_bus.connect() - pos_dict = motors_bus.read("Present_Position", motors=motors) + positions_read = motors_bus.read("Present_Position", motors=motors) - assert mock_motors.stubs["SyncRead_Present_Position_all"].called - assert all(returned_pos == pos for returned_pos, pos in zip(pos_dict.values(), positions, strict=True)) - assert set(pos_dict) == {"dummy_1", "dummy_2", "dummy_3"} - assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) + motors = ["dummy_1", "dummy_2", "dummy_3"] if motors is None else motors + assert mock_motors.stubs[stub_name].called + assert positions_read == dict(zip(motors, expected_positions.values(), strict=True)) @pytest.mark.parametrize( @@ -113,24 +165,20 @@ def test_read_all_motors(motors): [3, 4016], ], ) -def test_read_single_motor_by_name(idx, pos): - mock_motors = MockMotors([1, 2, 3]) - mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) +def test_read_single_motor_by_name(idx, pos, dummy_motors): + mock_motors = MockMotors() + expected_position = {idx: pos} + stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_position) motors_bus = FeetechMotorsBus( port=mock_motors.port, - motors={ - "dummy_1": (1, "sts3215"), - "dummy_2": (2, "sts3215"), - "dummy_3": (3, "sts3215"), - }, + motors=dummy_motors, ) motors_bus.connect() pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}") - assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called + assert mock_motors.stubs[stub_name].called assert pos_dict == {f"dummy_{idx}": pos} - assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) @pytest.mark.parametrize( @@ -141,56 +189,49 @@ def test_read_single_motor_by_name(idx, pos): [3, 4016], ], ) -def test_read_single_motor_by_id(idx, pos): - mock_motors = MockMotors([1, 2, 3]) - mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) +def test_read_single_motor_by_id(idx, pos, dummy_motors): + mock_motors = MockMotors() + expected_position = {idx: pos} + stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_position) motors_bus = FeetechMotorsBus( port=mock_motors.port, - motors={ - "dummy_1": (1, "sts3215"), - "dummy_2": (2, "sts3215"), - "dummy_3": (3, "sts3215"), - }, + motors=dummy_motors, ) motors_bus.connect() pos_dict = motors_bus.read("Present_Position", idx) - assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called - assert pos_dict == {f"dummy_{idx}": pos} - assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) + assert mock_motors.stubs[stub_name].called + assert pos_dict == {idx: pos} @pytest.mark.parametrize( "num_retry, num_invalid_try, pos", [ - [1, 2, 1337], + [0, 2, 1337], [2, 3, 42], [3, 2, 4016], [2, 1, 999], ], ) -def test_read_num_retry(num_retry, num_invalid_try, pos): - mock_motors = MockMotors([1, 2, 3]) - mock_motors.build_single_motor_stubs( - "Present_Position", return_value=pos, num_invalid_try=num_invalid_try +def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors): + mock_motors = MockMotors() + expected_position = {1: pos} + stub_name = mock_motors.build_sync_read_stub( + "Present_Position", expected_position, num_invalid_try=num_invalid_try ) motors_bus = FeetechMotorsBus( port=mock_motors.port, - motors={ - "dummy_1": (1, "sts3215"), - "dummy_2": (2, "sts3215"), - "dummy_3": (3, "sts3215"), - }, + motors=dummy_motors, ) motors_bus.connect() if num_retry >= num_invalid_try: pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry) - assert pos_dict == {"dummy_1": pos} - assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) + assert pos_dict == {1: pos} else: with pytest.raises(ConnectionError): _ = motors_bus.read("Present_Position", 1, num_retry=num_retry) - assert mock_motors.stubs["SyncRead_Present_Position_1"].calls == num_retry + expected_calls = min(1 + num_retry, 1 + num_invalid_try) + assert mock_motors.stubs[stub_name].calls == expected_calls