From 4005065223ceea9fd4bae7eb62528d68df5f8a53 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 10 Apr 2025 00:51:23 +0200 Subject: [PATCH] (nit) move write --- lerobot/common/motors/motors_bus.py | 88 ++++++++++++++--------------- tests/mocks/mock_dynamixel.py | 68 +++++++++++----------- tests/mocks/mock_feetech.py | 58 +++++++++---------- 3 files changed, 107 insertions(+), 107 deletions(-) diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 9f3fcdb2..16aa0402 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -784,6 +784,50 @@ class MotorsBus(abc.ABC): return value, comm, error + def write( + self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0 + ) -> None: + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." + ) + + id_ = self.motors[motor].id + model = self.motors[motor].model + addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) + + if normalize and data_name in self.normalized_data: + value = self._unnormalize(data_name, {id_: value})[id_] + + value = self._encode_sign(data_name, {id_: value})[id_] + + comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry) + if not self._is_comm_success(comm): + raise ConnectionError( + f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." + f"\n{self.packet_handler.getTxRxResult(comm)}" + ) + elif self._is_error(error): + raise RuntimeError( + f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." + f"\n{self.packet_handler.getRxPacketError(error)}" + ) + + def _write( + self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0 + ) -> tuple[int, int]: + data = self._split_int_to_bytes(value, n_bytes) + for n_try in range(1 + num_retry): + comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) + if self._is_comm_success(comm): + break + logger.debug( + f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) + ) + + return comm, error + def sync_read( self, data_name: str, @@ -914,50 +958,6 @@ class MotorsBus(abc.ABC): data = self._split_int_to_bytes(value, n_bytes) self.sync_writer.addParam(id_, data) - def write( - self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0 - ) -> None: - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." - ) - - id_ = self.motors[motor].id - model = self.motors[motor].model - addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) - - if normalize and data_name in self.normalized_data: - value = self._unnormalize(data_name, {id_: value})[id_] - - value = self._encode_sign(data_name, {id_: value})[id_] - - comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - f"\n{self.packet_handler.getTxRxResult(comm)}" - ) - elif self._is_error(error): - raise RuntimeError( - f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - f"\n{self.packet_handler.getRxPacketError(error)}" - ) - - def _write( - self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0 - ) -> tuple[int, int]: - data = self._split_int_to_bytes(value, n_bytes) - for n_try in range(1 + num_retry): - comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) - if self._is_comm_success(comm): - break - logger.debug( - f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): " - + self.packet_handler.getTxRxResult(comm) - ) - - return comm, error - def disconnect(self, disable_torque: bool = True) -> None: if not self.is_connected: raise DeviceNotConnectedError( diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 78738025..454d8da8 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -212,6 +212,40 @@ class MockInstructionPacket(MockDynamixelPacketv2): params, length = [], 3 return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Ping") + @classmethod + def write( + cls, + dxl_id: int, + value: int, + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Write" instruction. + https://emanual.robotis.com/docs/en/dxl/protocol2/#write-0x03 + + The parameters for Write (Protocol 2.0) are: + param[0] = start_address L + param[1] = start_address H + param[2] = 1st Byte + param[3] = 2nd Byte + ... + param[1+X] = X-th Byte + + And 'length' = data_length + 5, where: + +1 is for instruction byte, + +2 is for the length bytes, + +2 is for the CRC at the end. + """ + data = DynamixelMotorsBus._split_int_to_bytes(value, data_length) + params = [ + dxl.DXL_LOBYTE(start_address), + dxl.DXL_HIBYTE(start_address), + *data, + ] + length = data_length + 5 + return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Write") + @classmethod def sync_read( cls, @@ -293,40 +327,6 @@ class MockInstructionPacket(MockDynamixelPacketv2): length = len(ids_values) * (1 + data_length) + 7 return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write") - @classmethod - def write( - cls, - dxl_id: int, - value: int, - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Write" instruction. - https://emanual.robotis.com/docs/en/dxl/protocol2/#write-0x03 - - The parameters for Write (Protocol 2.0) are: - param[0] = start_address L - param[1] = start_address H - param[2] = 1st Byte - param[3] = 2nd Byte - ... - param[1+X] = X-th Byte - - And 'length' = data_length + 5, where: - +1 is for instruction byte, - +2 is for the length bytes, - +2 is for the CRC at the end. - """ - data = DynamixelMotorsBus._split_int_to_bytes(value, data_length) - params = [ - dxl.DXL_LOBYTE(start_address), - dxl.DXL_HIBYTE(start_address), - *data, - ] - length = data_length + 5 - return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Write") - class MockStatusPacket(MockDynamixelPacketv2): """ diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 82be9f20..dfddaa1f 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -115,6 +115,35 @@ class MockInstructionPacket(MockFeetechPacket): length = 4 return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Read") + @classmethod + def write( + cls, + scs_id: int, + value: int, + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Write" instruction. + + The parameters for Write are: + param[0] = start_address L + param[1] = start_address H + param[2] = 1st Byte + param[3] = 2nd Byte + ... + param[1+X] = X-th Byte + + And 'length' = data_length + 3, where: + +1 is for instruction byte, + +1 is for the length bytes, + +1 is for the checksum at the end. + """ + data = FeetechMotorsBus._split_int_to_bytes(value, data_length) + params = [start_address, *data] + length = data_length + 3 + return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") + @classmethod def sync_read( cls, @@ -178,35 +207,6 @@ class MockInstructionPacket(MockFeetechPacket): length = len(ids_values) * (1 + data_length) + 4 return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write") - @classmethod - def write( - cls, - scs_id: int, - value: int, - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Write" instruction. - - The parameters for Write are: - param[0] = start_address L - param[1] = start_address H - param[2] = 1st Byte - param[3] = 2nd Byte - ... - param[1+X] = X-th Byte - - And 'length' = data_length + 3, where: - +1 is for instruction byte, - +1 is for the length bytes, - +1 is for the checksum at the end. - """ - data = FeetechMotorsBus._split_int_to_bytes(value, data_length) - params = [start_address, *data] - length = data_length + 3 - return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") - class MockStatusPacket(MockFeetechPacket): """