From d32daebf75b84a69ff9b4a0a7e3b9582ceb20257 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 11 Apr 2025 11:01:12 +0200 Subject: [PATCH] Refactor & add _serialize_data --- lerobot/common/motors/dynamixel/dynamixel.py | 8 +- lerobot/common/motors/feetech/feetech.py | 8 +- lerobot/common/motors/motors_bus.py | 116 +++++++------------ tests/mocks/mock_dynamixel.py | 6 +- tests/mocks/mock_feetech.py | 8 +- tests/motors/test_dynamixel.py | 18 +-- tests/motors/test_feetech.py | 18 +-- 7 files changed, 78 insertions(+), 104 deletions(-) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index 8f69b8b5..a710afde 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -167,14 +167,14 @@ class DynamixelMotorsBus(MotorsBus): return half_turn_homings @staticmethod - def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]: + def _split_into_byte_chunks(value: int, length: int) -> list[int]: import dynamixel_sdk as dxl - if n_bytes == 1: + if length == 1: data = [value] - elif n_bytes == 2: + elif length == 2: data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] - elif n_bytes == 4: + elif length == 4: data = [ dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index f7557c97..a0796f9c 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -170,14 +170,14 @@ class FeetechMotorsBus(MotorsBus): return ids_values @staticmethod - def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]: + def _split_into_byte_chunks(value: int, length: int) -> list[int]: import scservo_sdk as scs - if n_bytes == 1: + if length == 1: data = [value] - elif n_bytes == 2: + elif length == 2: data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] - elif n_bytes == 4: + elif length == 4: data = [ scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 3c64be7b..7bc8a4ae 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -642,57 +642,31 @@ class MotorsBus(abc.ABC): def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: pass - def _serialize_data(self, value: int, n_bytes: int) -> list[int]: + def _serialize_data(self, value: int, length: int) -> list[int]: """ Converts an unsigned integer value into a list of byte-sized integers to be sent via a communication protocol. Depending on the protocol, split values can be in big-endian or little-endian order. - This function extracts the individual bytes of an integer based on the - specified number of bytes (`n_bytes`). The output is a list of integers, - each representing a byte (0-255). - - **Byte order:** The function returns bytes in **little-endian format**, - meaning the least significant byte (LSB) comes first. - - Args: - value (int): The unsigned integer to be converted into a byte list. Must be within - the valid range for the specified `n_bytes`. - n_bytes (int): The number of bytes to use for conversion. Supported values for both Feetech and - Dynamixel: - - 1 (for values 0 to 255) - - 2 (for values 0 to 65,535) - - 4 (for values 0 to 4,294,967,295) - - Raises: - ValueError: If `value` is negative or exceeds the maximum allowed for `n_bytes`. - NotImplementedError: If `n_bytes` is not 1, 2, or 4. - - Returns: - list[int]: A list of integers, each representing a byte in **little-endian order**. - - Examples (for a little-endian protocol): - >>> split_int_bytes(0x12, 1) - [18] - >>> split_int_bytes(0x1234, 2) - [52, 18] # 0x1234 → 0x34 0x12 (little-endian) - >>> split_int_bytes(0x12345678, 4) - [120, 86, 52, 18] # 0x12345678 → 0x78 0x56 0x34 0x12 + Supported data length for both Feetech and Dynamixel: + - 1 (for values 0 to 255) + - 2 (for values 0 to 65,535) + - 4 (for values 0 to 4,294,967,295) """ if value < 0: raise ValueError(f"Negative values are not allowed: {value}") - max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(n_bytes) + max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(length) if max_value is None: - raise NotImplementedError(f"Unsupported byte size: {n_bytes}. Expected [1, 2, 4].") + raise NotImplementedError(f"Unsupported byte size: {length}. Expected [1, 2, 4].") if value > max_value: - raise ValueError(f"Value {value} exceeds the maximum for {n_bytes} bytes ({max_value}).") + raise ValueError(f"Value {value} exceeds the maximum for {length} bytes ({max_value}).") - return self._split_into_byte_chunks(value, n_bytes) + return self._split_into_byte_chunks(value, length) @staticmethod @abc.abstractmethod - def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]: + def _split_into_byte_chunks(value: int, length: int) -> list[int]: """Convert an integer into a list of byte-sized integers.""" pass @@ -736,9 +710,9 @@ class MotorsBus(abc.ABC): id_ = self.motors[motor].id model = self.motors[motor].model - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + addr, length = get_address(self.model_ctrl_table, model, data_name) - value, comm, error = self._read(addr, n_bytes, id_, num_retry=num_retry) + value, comm, error = self._read(addr, length, id_, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." @@ -757,22 +731,22 @@ class MotorsBus(abc.ABC): return id_value[id_] - def _read(self, addr: int, n_bytes: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]: - if n_bytes == 1: + def _read(self, address: int, length: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]: + if length == 1: read_fn = self.packet_handler.read1ByteTxRx - elif n_bytes == 2: + elif length == 2: read_fn = self.packet_handler.read2ByteTxRx - elif n_bytes == 4: + elif length == 4: read_fn = self.packet_handler.read4ByteTxRx else: - raise ValueError(n_bytes) + raise ValueError(length) for n_try in range(1 + num_retry): - value, comm, error = read_fn(self.port_handler, motor_id, addr) + value, comm, error = read_fn(self.port_handler, motor_id, address) if self._is_comm_success(comm): break logger.debug( - f"Failed to read @{addr=} ({n_bytes=}) on {motor_id=} ({n_try=}): " + f"Failed to read @{address=} ({length=}) on {motor_id=} ({n_try=}): " + self.packet_handler.getTxRxResult(comm) ) @@ -788,14 +762,14 @@ class MotorsBus(abc.ABC): id_ = self.motors[motor].id model = self.motors[motor].model - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + addr, length = get_address(self.model_ctrl_table, model, data_name) if normalize and data_name in self.normalized_data: value = self._unnormalize(data_name, {id_: value})[id_] value = self._encode_sign(data_name, {id_: value})[id_] - comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry) + comm, error = self._write(addr, length, id_, value, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." @@ -808,15 +782,15 @@ class MotorsBus(abc.ABC): ) def _write( - self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0 + self, addr: int, length: int, motor_id: int, value: int, num_retry: int = 0 ) -> tuple[int, int]: - data = self._serialize_data(value, n_bytes) + data = self._serialize_data(value, length) for n_try in range(1 + num_retry): - comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) + comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, length, data) if self._is_comm_success(comm): break logger.debug( - f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): " + f"Failed to sync write @{addr=} ({length=}) on id={motor_id} with {value=} ({n_try=}): " + self.packet_handler.getTxRxResult(comm) ) @@ -845,9 +819,9 @@ class MotorsBus(abc.ABC): assert_same_address(self.model_ctrl_table, models, data_name) model = next(iter(models)) - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + addr, length = get_address(self.model_ctrl_table, model, data_name) - comm, ids_values = self._sync_read(addr, n_bytes, ids, num_retry=num_retry) + comm, ids_values = self._sync_read(addr, length, ids, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." @@ -862,25 +836,25 @@ class MotorsBus(abc.ABC): return {self._id_to_name(id_): value for id_, value in ids_values.items()} def _sync_read( - self, addr: int, n_bytes: int, motor_ids: list[int], num_retry: int = 0 + self, addr: int, length: int, motor_ids: list[int], num_retry: int = 0 ) -> tuple[int, dict[int, int]]: - self._setup_sync_reader(motor_ids, addr, n_bytes) + self._setup_sync_reader(motor_ids, addr, length) for n_try in range(1 + num_retry): comm = self.sync_reader.txRxPacket() if self._is_comm_success(comm): break logger.debug( - f"Failed to sync read @{addr=} ({n_bytes=}) on {motor_ids=} ({n_try=}): " + f"Failed to sync read @{addr=} ({length=}) on {motor_ids=} ({n_try=}): " + self.packet_handler.getTxRxResult(comm) ) - values = {id_: self.sync_reader.getData(id_, addr, n_bytes) for id_ in motor_ids} + values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids} return comm, values - def _setup_sync_reader(self, motor_ids: list[int], addr: int, n_bytes: int) -> None: + def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None: self.sync_reader.clearParam() self.sync_reader.start_address = addr - self.sync_reader.data_length = n_bytes + self.sync_reader.data_length = length for id_ in motor_ids: self.sync_reader.addParam(id_) @@ -888,15 +862,15 @@ class MotorsBus(abc.ABC): # Would have to handle the logic of checking if a packet has been sent previously though but doable. # This could be at the cost of increase latency between the moment the data is produced by the motors and # the moment it is used by a policy. - # def _async_read(self, motor_ids: list[int], address: int, n_bytes: int): - # if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...: - # self._setup_sync_reader(motor_ids, address, n_bytes) + # def _async_read(self, motor_ids: list[int], address: int, length: int): + # if self.sync_reader.start_address != address or self.sync_reader.data_length != length or ...: + # self._setup_sync_reader(motor_ids, address, length) # else: # self.sync_reader.rxPacket() # self.sync_reader.txPacket() # for id_ in motor_ids: - # value = self.sync_reader.getData(id_, address, n_bytes) + # value = self.sync_reader.getData(id_, address, length) def sync_write( self, @@ -917,39 +891,39 @@ class MotorsBus(abc.ABC): assert_same_address(self.model_ctrl_table, models, data_name) model = next(iter(models)) - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + addr, length = get_address(self.model_ctrl_table, model, data_name) if normalize and data_name in self.normalized_data: ids_values = self._unnormalize(data_name, ids_values) ids_values = self._encode_sign(data_name, ids_values) - comm = self._sync_write(addr, n_bytes, ids_values, num_retry=num_retry) + comm = self._sync_write(addr, length, ids_values, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." f"\n{self.packet_handler.getTxRxResult(comm)}" ) - def _sync_write(self, addr: int, n_bytes: int, ids_values: dict[int, int], num_retry: int = 0) -> int: - self._setup_sync_writer(ids_values, addr, n_bytes) + def _sync_write(self, addr: int, length: int, ids_values: dict[int, int], num_retry: int = 0) -> int: + self._setup_sync_writer(ids_values, addr, length) for n_try in range(1 + num_retry): comm = self.sync_writer.txPacket() if self._is_comm_success(comm): break logger.debug( - f"Failed to sync write @{addr=} ({n_bytes=}) with {ids_values=} ({n_try=}): " + f"Failed to sync write @{addr=} ({length=}) with {ids_values=} ({n_try=}): " + self.packet_handler.getTxRxResult(comm) ) return comm - def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, n_bytes: int) -> None: + def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None: self.sync_writer.clearParam() self.sync_writer.start_address = addr - self.sync_writer.data_length = n_bytes + self.sync_writer.data_length = length for id_, value in ids_values.items(): - data = self._serialize_data(value, n_bytes) + data = self._serialize_data(value, length) self.sync_writer.addParam(id_, data) def disconnect(self, disable_torque: bool = True) -> None: diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 454d8da8..feae051b 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -237,7 +237,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): +2 is for the length bytes, +2 is for the CRC at the end. """ - data = DynamixelMotorsBus._split_int_to_bytes(value, data_length) + data = DynamixelMotorsBus._split_into_byte_chunks(value, data_length) params = [ dxl.DXL_LOBYTE(start_address), dxl.DXL_HIBYTE(start_address), @@ -315,7 +315,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): """ data = [] for id_, value in ids_values.items(): - split_value = DynamixelMotorsBus._split_int_to_bytes(value, data_length) + split_value = DynamixelMotorsBus._split_into_byte_chunks(value, data_length) data += [id_, *split_value] params = [ dxl.DXL_LOBYTE(start_address), @@ -389,7 +389,7 @@ class MockStatusPacket(MockDynamixelPacketv2): Returns: bytes: The raw 'Present_Position' status packet ready to be sent through serial. """ - params = DynamixelMotorsBus._split_int_to_bytes(value, param_length) + params = DynamixelMotorsBus._split_into_byte_chunks(value, param_length) length = param_length + 4 return cls.build(dxl_id, params=params, length=length) diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index dfddaa1f..57bd8cbc 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -49,7 +49,7 @@ class MockFeetechPacket(abc.ABC): for id_ in range(2, len(packet) - 1): # except header & checksum checksum += packet[id_] - packet[-1] = scs.SCS_LOBYTE(~checksum) + packet[-1] = ~checksum & 0xFF return packet @@ -139,7 +139,7 @@ class MockInstructionPacket(MockFeetechPacket): +1 is for the length bytes, +1 is for the checksum at the end. """ - data = FeetechMotorsBus._split_int_to_bytes(value, data_length) + data = FeetechMotorsBus._split_into_byte_chunks(value, data_length) params = [start_address, *data] length = data_length + 3 return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") @@ -201,7 +201,7 @@ class MockInstructionPacket(MockFeetechPacket): """ data = [] for id_, value in ids_values.items(): - split_value = FeetechMotorsBus._split_int_to_bytes(value, data_length) + split_value = FeetechMotorsBus._split_into_byte_chunks(value, data_length) data += [id_, *split_value] params = [start_address, data_length, *data] length = len(ids_values) * (1 + data_length) + 4 @@ -258,7 +258,7 @@ class MockStatusPacket(MockFeetechPacket): Returns: bytes: The raw 'Sync Read' status packet ready to be sent through serial. """ - params = FeetechMotorsBus._split_int_to_bytes(value, param_length) + params = FeetechMotorsBus._split_into_byte_chunks(value, param_length) length = param_length + 2 return cls.build(scs_id, params=params, length=length) diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index 6fd0e3a7..e047e7c1 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -62,7 +62,7 @@ def test_autouse_patch(): @pytest.mark.parametrize( - "value, n_bytes, expected", + "value, length, expected", [ (0x12, 1, [0x12]), (0x1234, 2, [0x34, 0x12]), @@ -86,24 +86,24 @@ def test_autouse_patch(): "max four bytes", ], ) # fmt: skip -def test_split_int_to_bytes(value, n_bytes, expected): - assert DynamixelMotorsBus._split_int_to_bytes(value, n_bytes) == expected +def test_serialize_data(value, length, expected): + assert DynamixelMotorsBus._serialize_data(value, length) == expected -def test_split_int_to_bytes_invalid_n_bytes(): +def test_serialize_data_invalid_length(): with pytest.raises(NotImplementedError): - DynamixelMotorsBus._split_int_to_bytes(100, 3) + DynamixelMotorsBus._serialize_data(100, 3) -def test_split_int_to_bytes_negative_numbers(): +def test_serialize_data_negative_numbers(): with pytest.raises(ValueError): - neg = DynamixelMotorsBus._split_int_to_bytes(-1, 1) + neg = DynamixelMotorsBus._serialize_data(-1, 1) print(neg) -def test_split_int_to_bytes_large_number(): +def test_serialize_data_large_number(): with pytest.raises(ValueError): - DynamixelMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF + DynamixelMotorsBus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF def test_abc_implementation(dummy_motors): diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 5372c37a..da819464 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -61,7 +61,7 @@ def test_autouse_patch(): @pytest.mark.parametrize( - "value, n_bytes, expected", + "value, length, expected", [ (0x12, 1, [0x12]), (0x1234, 2, [0x34, 0x12]), @@ -85,24 +85,24 @@ def test_autouse_patch(): "max four bytes", ], ) # fmt: skip -def test_split_int_to_bytes(value, n_bytes, expected): - assert FeetechMotorsBus._split_int_to_bytes(value, n_bytes) == expected +def test_serialize_data(value, length, expected): + assert FeetechMotorsBus._serialize_data(value, length) == expected -def test_split_int_to_bytes_invalid_n_bytes(): +def test_serialize_data_invalid_length(): with pytest.raises(NotImplementedError): - FeetechMotorsBus._split_int_to_bytes(100, 3) + FeetechMotorsBus._serialize_data(100, 3) -def test_split_int_to_bytes_negative_numbers(): +def test_serialize_data_negative_numbers(): with pytest.raises(ValueError): - neg = FeetechMotorsBus._split_int_to_bytes(-1, 1) + neg = FeetechMotorsBus._serialize_data(-1, 1) print(neg) -def test_split_int_to_bytes_large_number(): +def test_serialize_data_large_number(): with pytest.raises(ValueError): - FeetechMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF + FeetechMotorsBus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF def test_abc_implementation(dummy_motors):