From bc020ee0a42b2207f31adada6ca23bd3b7abcb2a Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 20 Mar 2025 14:00:10 +0100 Subject: [PATCH] Remove mock_feetech sdk & add feetech new tests --- tests/mocks/mock_scservo_sdk.py | 151 ------------------------ tests/motors/test_feetech.py | 196 ++++++++++++++++++++++++++++---- 2 files changed, 174 insertions(+), 173 deletions(-) delete mode 100644 tests/mocks/mock_scservo_sdk.py diff --git a/tests/mocks/mock_scservo_sdk.py b/tests/mocks/mock_scservo_sdk.py deleted file mode 100644 index 21a3652b..00000000 --- a/tests/mocks/mock_scservo_sdk.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ruff: noqa: N802, E741 - -""" -Mocked classes and functions from scservo_sdk to allow for testing FeetechMotorsBus code. - -Warning: These mocked versions are minimalist. They do not exactly mock every behaviors -from the original classes and functions (e.g. return types might be None instead of boolean). -""" - -DEFAULT_BAUDRATE = 1_000_000 -COMM_SUCCESS = 0 # tx or rx packet communication success - - -def convert_to_bytes(value, bytes): - # TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform - # `convert_bytes_to_value` - del bytes # unused - return value - - -def get_default_motor_values(motor_index): - return { - # Key (int) are from SCS_SERIES_CONTROL_TABLE - 5: motor_index, # ID - 6: DEFAULT_BAUDRATE, # Baud_rate - 10: 0, # Drive_Mode - 21: 32, # P_Coefficient - 22: 32, # D_Coefficient - 23: 0, # I_Coefficient - 40: 0, # Torque_Enable - 41: 254, # Acceleration - 31: -2047, # Offset - 33: 0, # Mode - 55: 1, # Lock - # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 - # For other joints, 2560 will be autocorrected to be in calibration range - 56: 2560, # Present_Position - 58: 0, # Present_Speed - 69: 0, # Present_Current - 85: 150, # Maximum_Acceleration - } - - -# Macro for Control Table Value -def SCS_MAKEWORD(a, b): - return (a & 0xFF) | ((b & 0xFF) << 8) - - -def SCS_MAKEDWORD(a, b): - return (a & 0xFFFF) | (b & 0xFFFF) << 16 - - -def SCS_LOWORD(l): - return l & 0xFFFF - - -def SCS_HIWORD(l): - return (l >> 16) & 0xFFFF - - -def SCS_LOBYTE(w): - return w & 0xFF - - -def SCS_HIBYTE(w): - return (w >> 8) & 0xFF - - -class PortHandler: - def __init__(self, port): - self.port = port - # factory default baudrate - self.baudrate = DEFAULT_BAUDRATE - self.ser = SerialMock() - - def openPort(self): # noqa: N802 - return True - - def closePort(self): # noqa: N802 - pass - - def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802 - del timeout_ms # unused - - def getBaudRate(self): # noqa: N802 - return self.baudrate - - def setBaudRate(self, baudrate): # noqa: N802 - self.baudrate = baudrate - - -class PacketHandler: - def __init__(self, protocol_version): - del protocol_version # unused - # Use packet_handler.data to communicate across Read and Write - self.data = {} - - -class GroupSyncRead: - def __init__(self, port_handler, packet_handler, address, bytes): - self.packet_handler = packet_handler - - def addParam(self, motor_index): # noqa: N802 - # Initialize motor default values - if motor_index not in self.packet_handler.data: - self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) - - def txRxPacket(self): # noqa: N802 - return COMM_SUCCESS - - def getData(self, index, address, bytes): # noqa: N802 - return self.packet_handler.data[index][address] - - -class GroupSyncWrite: - def __init__(self, port_handler, packet_handler, address, bytes): - self.packet_handler = packet_handler - self.address = address - - def addParam(self, index, data): # noqa: N802 - if index not in self.packet_handler.data: - self.packet_handler.data[index] = get_default_motor_values(index) - self.changeParam(index, data) - - def txPacket(self): # noqa: N802 - return COMM_SUCCESS - - def changeParam(self, index, data): # noqa: N802 - self.packet_handler.data[index][self.address] = data - - -class SerialMock: - def reset_output_buffer(self): - pass - - def reset_input_buffer(self): - pass diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 6c580f6a..42292fb0 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -2,43 +2,195 @@ import sys from unittest.mock import patch import pytest +import scservo_sdk as scs from lerobot.common.motors.feetech import FeetechMotorsBus -from tests.mocks import mock_scservo_sdk +from tests.mocks.mock_feetech import MockMotors, MockPortHandler @pytest.fixture(autouse=True) -def patch_scservo_sdk(): - with patch.dict(sys.modules, {"scservo_sdk": mock_scservo_sdk}): +def patch_port_handler(): + if sys.platform == "darwin": + with patch.object(scs, "PortHandler", MockPortHandler): + yield + else: yield -def test_patch_sdk(): - assert "scservo_sdk" in sys.modules # Should be patched - assert sys.modules["scservo_sdk"] is mock_scservo_sdk # Should match the mock +@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}") +def test_autouse_patch(): + """Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler.""" + assert scs.PortHandler is MockPortHandler + + +@pytest.mark.parametrize( + "value, n_bytes, expected", + [ + (0x12, 1, [0x12]), + (0x1234, 2, [0x34, 0x12]), + (0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), + (0, 1, [0x00]), + (0, 2, [0x00, 0x00]), + (0, 4, [0x00, 0x00, 0x00, 0x00]), + (255, 1, [0xFF]), + (65535, 2, [0xFF, 0xFF]), + (4294967295, 4, [0xFF, 0xFF, 0xFF, 0xFF]), + ], + ids=[ + "1 byte", + "2 bytes", + "4 bytes", + "0 with 1 byte", + "0 with 2 bytes", + "0 with 4 bytes", + "max single byte", + "max two bytes", + "max four bytes", + ], +) # fmt: skip +def test_split_int_bytes(value, n_bytes, expected): + assert FeetechMotorsBus.split_int_bytes(value, n_bytes) == expected + + +def test_split_int_bytes_invalid_n_bytes(): + with pytest.raises(NotImplementedError): + FeetechMotorsBus.split_int_bytes(100, 3) + + +def test_split_int_bytes_negative_numbers(): + with pytest.raises(ValueError): + neg = FeetechMotorsBus.split_int_bytes(-1, 1) + print(neg) + + +def test_split_int_bytes_large_number(): + with pytest.raises(ValueError): + FeetechMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF def test_abc_implementation(): - # Instantiation should raise an error if the class doesn't implements abstract methods/properties + """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" FeetechMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "sts3215")}) -def test_configure_motors_all_ids_1(): - # see SCS_SERIES_BAUDRATE_TABLE - smaller_baudrate = 19_200 - smaller_baudrate_value = 7 - - # This test expect the configuration was already correct. - motors_bus = FeetechMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "sts3215")}) +@pytest.mark.parametrize( + "motors", + [ + None, + [1, 2, 3], + ["dummy_1", "dummy_2", "dummy_3"], + [1, "dummy_2", 3], + ], + ids=["None", "by ids", "by names", "mixed"], +) +def test_read_all_motors(motors): + mock_motors = MockMotors([1, 2, 3]) + positions = [1337, 42, 4016] + mock_motors.build_all_motors_stub("Present_Position", return_values=positions) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors={ + "dummy_1": (1, "sts3215"), + "dummy_2": (2, "sts3215"), + "dummy_3": (3, "sts3215"), + }, + ) motors_bus.connect() - motors_bus.write("Baud_Rate", [smaller_baudrate_value] * len(motors_bus)) - motors_bus.set_baudrate(smaller_baudrate) - motors_bus.write("ID", [1] * len(motors_bus)) - del motors_bus + pos_dict = motors_bus.read("Present_Position", motors=motors) - # Test configure - motors_bus = FeetechMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "sts3215")}) + assert mock_motors.stubs["SyncRead_Present_Position_all"].called + assert all(returned_pos == pos for returned_pos, pos in zip(pos_dict.values(), positions, strict=True)) + assert set(pos_dict) == {"dummy_1", "dummy_2", "dummy_3"} + assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) + + +@pytest.mark.parametrize( + "idx, pos", + [ + [1, 1337], + [2, 42], + [3, 4016], + ], +) +def test_read_single_motor_by_name(idx, pos): + mock_motors = MockMotors([1, 2, 3]) + mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors={ + "dummy_1": (1, "sts3215"), + "dummy_2": (2, "sts3215"), + "dummy_3": (3, "sts3215"), + }, + ) motors_bus.connect() - assert motors_bus.are_motors_configured() - del motors_bus + + pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}") + + assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called + assert pos_dict == {f"dummy_{idx}": pos} + assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) + + +@pytest.mark.parametrize( + "idx, pos", + [ + [1, 1337], + [2, 42], + [3, 4016], + ], +) +def test_read_single_motor_by_id(idx, pos): + mock_motors = MockMotors([1, 2, 3]) + mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors={ + "dummy_1": (1, "sts3215"), + "dummy_2": (2, "sts3215"), + "dummy_3": (3, "sts3215"), + }, + ) + motors_bus.connect() + + pos_dict = motors_bus.read("Present_Position", idx) + + assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called + assert pos_dict == {f"dummy_{idx}": pos} + assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) + + +@pytest.mark.parametrize( + "num_retry, num_invalid_try, pos", + [ + [1, 2, 1337], + [2, 3, 42], + [3, 2, 4016], + [2, 1, 999], + ], +) +def test_read_num_retry(num_retry, num_invalid_try, pos): + mock_motors = MockMotors([1, 2, 3]) + mock_motors.build_single_motor_stubs( + "Present_Position", return_value=pos, num_invalid_try=num_invalid_try + ) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors={ + "dummy_1": (1, "sts3215"), + "dummy_2": (2, "sts3215"), + "dummy_3": (3, "sts3215"), + }, + ) + motors_bus.connect() + + if num_retry >= num_invalid_try: + pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry) + assert pos_dict == {"dummy_1": pos} + assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values()) + else: + with pytest.raises(ConnectionError): + _ = motors_bus.read("Present_Position", 1, num_retry=num_retry) + + assert mock_motors.stubs["SyncRead_Present_Position_1"].calls == num_retry