diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index a710afde..1ebefac0 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -84,6 +84,23 @@ class TorqueMode(Enum): DISABLED = 0 +def _split_into_byte_chunks(value: int, length: int) -> list[int]: + import dynamixel_sdk as dxl + + if length == 1: + data = [value] + elif length == 2: + data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] + elif length == 4: + data = [ + dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), + ] + return data + + class DynamixelMotorsBus(MotorsBus): """ The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with @@ -166,22 +183,8 @@ class DynamixelMotorsBus(MotorsBus): return half_turn_homings - @staticmethod - def _split_into_byte_chunks(value: int, length: int) -> list[int]: - import dynamixel_sdk as dxl - - if length == 1: - data = [value] - elif length == 2: - data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] - elif length == 4: - data = [ - dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), - dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), - dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), - dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), - ] - return data + def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: + return _split_into_byte_chunks(value, length) def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: for n_try in range(1 + num_retry): diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index a6b0c380..5e957f2f 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -64,6 +64,23 @@ class TorqueMode(Enum): DISABLED = 0 +def _split_into_byte_chunks(value: int, length: int) -> list[int]: + import scservo_sdk as scs + + if length == 1: + data = [value] + elif length == 2: + data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] + elif length == 4: + data = [ + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), + scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), + scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), + ] + return data + + def patch_setPacketTimeout(self, packet_length): # noqa: N802 """ HACK: This patches the PortHandler behavior to set the correct packet timeouts. @@ -169,22 +186,8 @@ class FeetechMotorsBus(MotorsBus): return ids_values - @staticmethod - def _split_into_byte_chunks(value: int, length: int) -> list[int]: - import scservo_sdk as scs - - if length == 1: - data = [value] - elif length == 2: - data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] - elif length == 4: - data = [ - scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), - scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), - scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), - scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), - ] - return data + def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: + return _split_into_byte_chunks(value, length) def _broadcast_ping_p1( self, known_motors_only: bool = True, n_motors: int | None = None, num_retry: int = 0 diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 7bc8a4ae..efc81166 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -664,9 +664,8 @@ class MotorsBus(abc.ABC): return self._split_into_byte_chunks(value, length) - @staticmethod @abc.abstractmethod - def _split_into_byte_chunks(value: int, length: int) -> list[int]: + def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: """Convert an integer into a list of byte-sized integers.""" pass diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index feae051b..1c1ab6fe 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -5,7 +5,8 @@ import dynamixel_sdk as dxl import serial from mock_serial.mock_serial import MockSerial -from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE, DynamixelMotorsBus +from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE +from lerobot.common.motors.dynamixel.dynamixel import _split_into_byte_chunks from .mock_serial_patch import WaitableStub @@ -237,7 +238,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): +2 is for the length bytes, +2 is for the CRC at the end. """ - data = DynamixelMotorsBus._split_into_byte_chunks(value, data_length) + data = _split_into_byte_chunks(value, data_length) params = [ dxl.DXL_LOBYTE(start_address), dxl.DXL_HIBYTE(start_address), @@ -315,7 +316,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): """ data = [] for id_, value in ids_values.items(): - split_value = DynamixelMotorsBus._split_into_byte_chunks(value, data_length) + split_value = _split_into_byte_chunks(value, data_length) data += [id_, *split_value] params = [ dxl.DXL_LOBYTE(start_address), @@ -389,7 +390,7 @@ class MockStatusPacket(MockDynamixelPacketv2): Returns: bytes: The raw 'Present_Position' status packet ready to be sent through serial. """ - params = DynamixelMotorsBus._split_into_byte_chunks(value, param_length) + params = _split_into_byte_chunks(value, param_length) length = param_length + 4 return cls.build(dxl_id, params=params, length=length) diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 57bd8cbc..2b54ae91 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -5,8 +5,8 @@ import scservo_sdk as scs import serial from mock_serial import MockSerial -from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE, FeetechMotorsBus -from lerobot.common.motors.feetech.feetech import patch_setPacketTimeout +from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE +from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout from .mock_serial_patch import WaitableStub @@ -139,7 +139,7 @@ class MockInstructionPacket(MockFeetechPacket): +1 is for the length bytes, +1 is for the checksum at the end. """ - data = FeetechMotorsBus._split_into_byte_chunks(value, data_length) + data = _split_into_byte_chunks(value, data_length) params = [start_address, *data] length = data_length + 3 return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") @@ -201,7 +201,7 @@ class MockInstructionPacket(MockFeetechPacket): """ data = [] for id_, value in ids_values.items(): - split_value = FeetechMotorsBus._split_into_byte_chunks(value, data_length) + split_value = _split_into_byte_chunks(value, data_length) data += [id_, *split_value] params = [start_address, data_length, *data] length = len(ids_values) * (1 + data_length) + 4 @@ -258,7 +258,7 @@ class MockStatusPacket(MockFeetechPacket): Returns: bytes: The raw 'Sync Read' status packet ready to be sent through serial. """ - params = FeetechMotorsBus._split_into_byte_chunks(value, param_length) + params = _split_into_byte_chunks(value, param_length) length = param_length + 2 return cls.build(scs_id, params=params, length=length) diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index e047e7c1..2b708836 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -67,43 +67,16 @@ def test_autouse_patch(): (0x12, 1, [0x12]), (0x1234, 2, [0x34, 0x12]), (0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), - (0, 1, [0x00]), - (0, 2, [0x00, 0x00]), - (0, 4, [0x00, 0x00, 0x00, 0x00]), - (255, 1, [0xFF]), - (65535, 2, [0xFF, 0xFF]), - (4294967295, 4, [0xFF, 0xFF, 0xFF, 0xFF]), ], ids=[ "1 byte", "2 bytes", "4 bytes", - "0 with 1 byte", - "0 with 2 bytes", - "0 with 4 bytes", - "max single byte", - "max two bytes", - "max four bytes", ], ) # fmt: skip -def test_serialize_data(value, length, expected): - assert DynamixelMotorsBus._serialize_data(value, length) == expected - - -def test_serialize_data_invalid_length(): - with pytest.raises(NotImplementedError): - DynamixelMotorsBus._serialize_data(100, 3) - - -def test_serialize_data_negative_numbers(): - with pytest.raises(ValueError): - neg = DynamixelMotorsBus._serialize_data(-1, 1) - print(neg) - - -def test_serialize_data_large_number(): - with pytest.raises(ValueError): - DynamixelMotorsBus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF +def test__split_into_byte_chunks(value, length, expected): + bus = DynamixelMotorsBus("", {}) + assert bus._split_into_byte_chunks(value, length) == expected def test_abc_implementation(dummy_motors): diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index da819464..2d3d4db7 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -61,48 +61,27 @@ def test_autouse_patch(): @pytest.mark.parametrize( - "value, length, expected", + "protocol, value, length, expected", [ - (0x12, 1, [0x12]), - (0x1234, 2, [0x34, 0x12]), - (0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), - (0, 1, [0x00]), - (0, 2, [0x00, 0x00]), - (0, 4, [0x00, 0x00, 0x00, 0x00]), - (255, 1, [0xFF]), - (65535, 2, [0xFF, 0xFF]), - (4294967295, 4, [0xFF, 0xFF, 0xFF, 0xFF]), + (0, 0x12, 1, [0x12]), + (1, 0x12, 1, [0x12]), + (0, 0x1234, 2, [0x34, 0x12]), + (1, 0x1234, 2, [0x12, 0x34]), + (0, 0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), + (1, 0x12345678, 4, [0x56, 0x78, 0x12, 0x34]), ], ids=[ - "1 byte", - "2 bytes", - "4 bytes", - "0 with 1 byte", - "0 with 2 bytes", - "0 with 4 bytes", - "max single byte", - "max two bytes", - "max four bytes", + "P0: 1 byte", + "P1: 1 byte", + "P0: 2 bytes", + "P1: 2 bytes", + "P0: 4 bytes", + "P1: 4 bytes", ], ) # fmt: skip -def test_serialize_data(value, length, expected): - assert FeetechMotorsBus._serialize_data(value, length) == expected - - -def test_serialize_data_invalid_length(): - with pytest.raises(NotImplementedError): - FeetechMotorsBus._serialize_data(100, 3) - - -def test_serialize_data_negative_numbers(): - with pytest.raises(ValueError): - neg = FeetechMotorsBus._serialize_data(-1, 1) - print(neg) - - -def test_serialize_data_large_number(): - with pytest.raises(ValueError): - FeetechMotorsBus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF +def test__split_into_byte_chunks(protocol, value, length, expected): + bus = FeetechMotorsBus("", {}, protocol_version=protocol) + assert bus._split_into_byte_chunks(value, length) == expected def test_abc_implementation(dummy_motors): diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py index 7463ae8c..8ceaeefa 100644 --- a/tests/motors/test_motors_bus.py +++ b/tests/motors/test_motors_bus.py @@ -2,12 +2,50 @@ import re import pytest -from lerobot.common.motors.motors_bus import assert_same_address, get_address, get_ctrl_table +from lerobot.common.motors.motors_bus import ( + Motor, + MotorsBus, + assert_same_address, + get_address, + get_ctrl_table, +) -# TODO(aliberts) -# class DummyMotorsBus(MotorsBus): -# def __init__(self, port: str, motors: dict[str, Motor]): -# super().__init__(port, motors) +DUMMY_CTRL_TABLE = {"Present_Position": (13, 4)} + +DUMMY_BAUDRATE_TABLE = { + 0: 1_000_000, + 1: 500_000, +} + +DUMMY_ENCODING_TABLE = { + "Present_Position": 8, +} + +DUMMY_MODEL_NUMBER_TABLE = {""} + + +class DummyMotorsBus(MotorsBus): + available_baudrates = [1_000_000] + default_timeout = 1000 + model_baudrate_table = {"model": DUMMY_BAUDRATE_TABLE} + model_ctrl_table = {"model": DUMMY_CTRL_TABLE} + model_encoding_table = {"model": DUMMY_ENCODING_TABLE} + model_number_table = {"model": 1234} + model_resolution_table = {"model": 4096} + normalized_data = ["Present_Position"] + + def __init__(self, port: str, motors: dict[str, Motor]): + super().__init__(port, motors) + + def _assert_protocol_is_compatible(self, instruction_name): ... + def configure_motors(self): ... + def disable_torque(self, motors): ... + def enable_torque(self, motors): ... + def _get_half_turn_homings(self, positions): ... + def _encode_sign(self, data_name, ids_values): ... + def _decode_sign(self, data_name, ids_values): ... + def _split_into_byte_chunks(self, value, length): ... + def broadcast_ping(self, num_retry, raise_on_error): ... @pytest.fixture @@ -85,3 +123,21 @@ def test_assert_same_address_different_bytes(model_ctrl_table): match=re.escape("At least two motor models use a different bytes representation"), ): assert_same_address(model_ctrl_table, models, "Goal_Position") + + +def test__serialize_data_invalid_length(): + bus = DummyMotorsBus("", {}) + with pytest.raises(NotImplementedError): + bus._serialize_data(100, 3) + + +def test__serialize_data_negative_numbers(): + bus = DummyMotorsBus("", {}) + with pytest.raises(ValueError): + bus._serialize_data(-1, 1) + + +def test__serialize_data_large_number(): + bus = DummyMotorsBus("", {}) + with pytest.raises(ValueError): + bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF