diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 5e957f2f..193f1b4a 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -21,12 +21,13 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value from .tables import ( - FIRMWARE_VERSION, + FIRMWARE_MAJOR_VERSION, MODEL_BAUDRATE_TABLE, MODEL_CONTROL_TABLE, MODEL_ENCODING_TABLE, MODEL_NUMBER, MODEL_NUMBER_TABLE, + MODEL_PROTOCOL, MODEL_RESOLUTION, SCAN_BAUDRATES, ) @@ -117,9 +118,10 @@ class FeetechMotorsBus(MotorsBus): protocol_version: int = DEFAULT_PROTOCOL_VERSION, ): super().__init__(port, motors, calibration) + self.protocol_version = protocol_version + self._assert_same_protocol() import scservo_sdk as scs - self.protocol_version = protocol_version self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( @@ -131,10 +133,21 @@ class FeetechMotorsBus(MotorsBus): self._comm_success = scs.COMM_SUCCESS self._no_error = 0x00 + if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models): + raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}") + + def _assert_same_protocol(self) -> None: + if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models): + raise RuntimeError("Some motors use an incompatible protocol.") + def _assert_protocol_is_compatible(self, instruction_name: str) -> None: if instruction_name == "sync_read" and self.protocol_version == 1: raise NotImplementedError( - "'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' instead." + "'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead." + ) + if instruction_name == "broadcast_ping" and self.protocol_version == 1: + raise NotImplementedError( + "'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead." ) def configure_motors(self) -> None: @@ -157,12 +170,12 @@ class FeetechMotorsBus(MotorsBus): return half_turn_homings def disable_torque(self, motors: str | list[str] | None = None) -> None: - for name in self._get_names_list(motors): + for name in self._get_motors_list(motors): self.write("Torque_Enable", name, TorqueMode.DISABLED.value) self.write("Lock", name, 0) def enable_torque(self, motors: str | list[str] | None = None) -> None: - for name in self._get_names_list(motors): + for name in self._get_motors_list(motors): self.write("Torque_Enable", name, TorqueMode.ENABLED.value) self.write("Lock", name, 1) @@ -286,56 +299,52 @@ class FeetechMotorsBus(MotorsBus): rx_length = rx_length - idx def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: - if self.protocol_version == 0: - for n_try in range(1 + num_retry): - ids_status, comm = self._broadcast_ping_p0() - if self._is_comm_success(comm): - break - logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})") - logger.debug(self.packet_handler.getTxRxResult(comm)) + self._assert_protocol_is_compatible("broadcast_ping") + for n_try in range(1 + num_retry): + ids_status, comm = self._broadcast_ping_p0() + if self._is_comm_success(comm): + break + logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})") + logger.debug(self.packet_handler.getTxRxResult(comm)) - if not self._is_comm_success(comm): - if raise_on_error: - raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return - - ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} - if ids_errors: - display_dict = { - id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items() - } - logger.error( - f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}" - ) - - return self._get_model_number(list(ids_status), raise_on_error) - else: - return self._broadcast_ping_p1(num_retry=num_retry) - - def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: - comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids) if not self._is_comm_success(comm): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) return + ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} + if ids_errors: + display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()} + logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}") + + return self._get_model_number(list(ids_status), raise_on_error) + + def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]: + firmware_versions = {} + for id_ in motor_ids: + firm_ver_major, comm, error = self._read( + *FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error + ) + if not self._is_comm_success(comm) or self._is_error(error): + return + + firm_ver_minor, comm, error = self._read( + *FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error + ) + if not self._is_comm_success(comm) or self._is_error(error): + return + + firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}" + return firmware_versions def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: - if self.protocol_version == 1: - model_numbers = {} - for id_ in motor_ids: - model_nb, comm, error = self._read(*MODEL_NUMBER, id_) - if self._is_comm_success(comm) and not self._is_error(error): - model_numbers[id_] = model_nb - elif raise_on_error: - raise Exception # FIX - - else: - comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids) - if not self._is_comm_success(comm): - if raise_on_error: - raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + model_numbers = {} + for id_ in motor_ids: + model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error) + if not self._is_comm_success(comm) or self._is_error(error): return + model_numbers[id_] = model_nb + return model_numbers diff --git a/lerobot/common/motors/feetech/tables.py b/lerobot/common/motors/feetech/tables.py index 17603317..ada8d08f 100644 --- a/lerobot/common/motors/feetech/tables.py +++ b/lerobot/common/motors/feetech/tables.py @@ -1,9 +1,5 @@ FIRMWARE_MAJOR_VERSION = (0, 1) FIRMWARE_MINOR_VERSION = (1, 1) -MODEL_MAJOR_VERSION = (3, 1) -MODEL_MINOR_VERSION = (4, 1) - -FIRMWARE_VERSION = (0, 2) MODEL_NUMBER = (3, 2) # See this link for STS3215 Memory Table: @@ -11,12 +7,9 @@ MODEL_NUMBER = (3, 2) # data_name: (address, size_byte) STS_SMS_SERIES_CONTROL_TABLE = { # EPROM - "Firmware_Version": FIRMWARE_VERSION, # read-only + "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only + "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only "Model_Number": MODEL_NUMBER, # read-only - # "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only - # "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only - # "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only - # "Model_Minor_Version": MODEL_MINOR_VERSION, "ID": (5, 1), "Baud_Rate": (6, 1), "Return_Delay_Time": (7, 1), @@ -68,12 +61,9 @@ STS_SMS_SERIES_CONTROL_TABLE = { SCS_SERIES_CONTROL_TABLE = { # EPROM - "Firmware_Version": FIRMWARE_VERSION, # read-only + "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only + "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only "Model_Number": MODEL_NUMBER, # read-only - # "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only - # "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only - # "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only - # "Model_Minor_Version": MODEL_MINOR_VERSION, "ID": (5, 1), "Baud_Rate": (6, 1), "Return_Delay": (7, 1), @@ -194,10 +184,19 @@ SCAN_BAUDRATES = [ 1_000_000, ] -# {model: model_number} TODO MODEL_NUMBER_TABLE = { "sts3215": 777, - "sts3250": None, + "sts3250": 2825, "sm8512bl": 11272, "scs0009": 1284, } + +MODEL_PROTOCOL = { + "sts_series": 0, + "sms_series": 0, + "scs_series": 1, + "sts3215": 0, + "sts3250": 0, + "sm8512bl": 0, + "scs0009": 1, +} diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index cada33a7..d0f8ff3e 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -283,6 +283,8 @@ class MotorsBus(abc.ABC): self._id_to_name_dict = {m.id: name for name, m in self.motors.items()} self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()} + self._validate_motors() + def __len__(self): return len(self.motors) @@ -341,7 +343,7 @@ class MotorsBus(abc.ABC): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_names_list(self, motors: str | list[str] | None) -> list[str]: + def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: if motors is None: return self.names elif isinstance(motors, str): @@ -422,8 +424,8 @@ class MotorsBus(abc.ABC): logger.debug(f"{self.__class__.__name__} connected.") @classmethod - def scan_port(cls, port: str) -> dict[int, list[int]]: - bus = cls(port, {}) + def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]: + bus = cls(port, {}, *args, **kwargs) try: bus.port_handler.openPort() except (FileNotFoundError, OSError, serial.SerialException) as e: @@ -715,17 +717,8 @@ class MotorsBus(abc.ABC): model = self.motors[motor].model addr, length = get_address(self.model_ctrl_table, model, data_name) - value, comm, error = self._read(addr, length, id_, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." - f"{self.packet_handler.getTxRxResult(comm)}" - ) - elif self._is_error(error): - raise RuntimeError( - f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." - f"\n{self.packet_handler.getRxPacketError(error)}" - ) + err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." + value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) id_value = self._decode_sign(data_name, {id_: value}) @@ -734,7 +727,16 @@ class MotorsBus(abc.ABC): return id_value[id_] - def _read(self, address: int, length: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]: + def _read( + self, + address: int, + length: int, + motor_id: int, + *, + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> tuple[int, int]: if length == 1: read_fn = self.packet_handler.read1ByteTxRx elif length == 2: @@ -753,6 +755,11 @@ class MotorsBus(abc.ABC): + self.packet_handler.getTxRxResult(comm) ) + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + elif self._is_error(error) and raise_on_error: + raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}") + return value, comm, error def write( @@ -772,20 +779,19 @@ class MotorsBus(abc.ABC): value = self._encode_sign(data_name, {id_: value})[id_] - comm, error = self._write(addr, length, id_, value, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - f"\n{self.packet_handler.getTxRxResult(comm)}" - ) - elif self._is_error(error): - raise RuntimeError( - f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - f"\n{self.packet_handler.getRxPacketError(error)}" - ) + err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." + self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) def _write( - self, addr: int, length: int, motor_id: int, value: int, num_retry: int = 0 + self, + addr: int, + length: int, + motor_id: int, + value: int, + *, + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", ) -> tuple[int, int]: data = self._serialize_data(value, length) for n_try in range(1 + num_retry): @@ -797,6 +803,11 @@ class MotorsBus(abc.ABC): + self.packet_handler.getTxRxResult(comm) ) + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + elif self._is_error(error) and raise_on_error: + raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}") + return comm, error def sync_read( @@ -814,7 +825,7 @@ class MotorsBus(abc.ABC): self._assert_protocol_is_compatible("sync_read") - names = self._get_names_list(motors) + names = self._get_motors_list(motors) ids = [self.motors[name].id for name in names] models = [self.motors[name].model for name in names] @@ -824,12 +835,10 @@ class MotorsBus(abc.ABC): model = next(iter(models)) addr, length = get_address(self.model_ctrl_table, model, data_name) - comm, ids_values = self._sync_read(addr, length, ids, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." - f"{self.packet_handler.getTxRxResult(comm)}" - ) + err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." + ids_values, _ = self._sync_read( + addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg + ) ids_values = self._decode_sign(data_name, ids_values) @@ -839,8 +848,15 @@ class MotorsBus(abc.ABC): return {self._id_to_name(id_): value for id_, value in ids_values.items()} def _sync_read( - self, addr: int, length: int, motor_ids: list[int], num_retry: int = 0 - ) -> tuple[int, dict[int, int]]: + self, + addr: int, + length: int, + motor_ids: list[int], + *, + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> tuple[dict[int, int], int]: self._setup_sync_reader(motor_ids, addr, length) for n_try in range(1 + num_retry): comm = self.sync_reader.txRxPacket() @@ -851,8 +867,11 @@ class MotorsBus(abc.ABC): + self.packet_handler.getTxRxResult(comm) ) + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids} - return comm, values + return values, comm def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None: self.sync_reader.clearParam() @@ -901,14 +920,18 @@ class MotorsBus(abc.ABC): ids_values = self._encode_sign(data_name, ids_values) - comm = self._sync_write(addr, length, ids_values, num_retry=num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." - f"\n{self.packet_handler.getTxRxResult(comm)}" - ) + err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." + self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) - def _sync_write(self, addr: int, length: int, ids_values: dict[int, int], num_retry: int = 0) -> int: + def _sync_write( + self, + addr: int, + length: int, + ids_values: dict[int, int], + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> int: self._setup_sync_writer(ids_values, addr, length) for n_try in range(1 + num_retry): comm = self.sync_writer.txPacket() @@ -919,6 +942,9 @@ class MotorsBus(abc.ABC): + self.packet_handler.getTxRxResult(comm) ) + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + return comm def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None: diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 2b54ae91..f4bb1c68 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -10,27 +10,6 @@ from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch from .mock_serial_patch import WaitableStub -# https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf -INSTRUCTION_TYPES = { - "Read": scs.INST_PING, # Read data from the Device - "Ping": scs.INST_READ, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID - "Write": scs.INST_WRITE, # Write data to the Device - "Reg_Write": scs.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command - "Action": scs.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write - "Factory_Reset": 0x06, # Resets the Control Table to its initial factory default settings - "Sync_Write": scs.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once - "Sync_Read": scs.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once -} # fmt: skip - -ERROR_TYPE = { - "Success": 0x00, - "Voltage": scs.ERRBIT_VOLTAGE, - "Angle": scs.ERRBIT_ANGLE, - "Overheat": scs.ERRBIT_OVERHEAT, - "Overele": scs.ERRBIT_OVERELE, - "Overload": scs.ERRBIT_OVERLOAD, -} - class MockFeetechPacket(abc.ABC): @classmethod @@ -68,15 +47,14 @@ class MockInstructionPacket(MockFeetechPacket): """ @classmethod - def _build(cls, scs_id: int, params: list[int], length: int, instruct_type: str) -> list[int]: - instruct_value = INSTRUCTION_TYPES[instruct_type] + def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]: return [ - 0xFF, 0xFF, # header - scs_id, # servo id - length, # length - instruct_value, # instruction type - *params, # data bytes - 0x00, # placeholder for checksum + 0xFF, 0xFF, # header + scs_id, # servo id + length, # length + instruction, # instruction type + *params, # data bytes + 0x00, # placeholder for checksum ] # fmt: skip @classmethod @@ -89,7 +67,7 @@ class MockInstructionPacket(MockFeetechPacket): No parameters required. """ - return cls.build(scs_id=scs_id, params=[], length=2, instruct_type="Ping") + return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING) @classmethod def read( @@ -113,7 +91,7 @@ class MockInstructionPacket(MockFeetechPacket): """ params = [start_address, data_length] length = 4 - return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Read") + return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ) @classmethod def write( @@ -142,7 +120,7 @@ class MockInstructionPacket(MockFeetechPacket): data = _split_into_byte_chunks(value, data_length) params = [start_address, *data] length = data_length + 3 - return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") + return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE) @classmethod def sync_read( @@ -167,7 +145,9 @@ class MockInstructionPacket(MockFeetechPacket): """ params = [start_address, data_length, *scs_ids] length = len(scs_ids) + 4 - return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read") + return cls.build( + scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ + ) @classmethod def sync_write( @@ -205,7 +185,9 @@ class MockInstructionPacket(MockFeetechPacket): data += [id_, *split_value] params = [start_address, data_length, *data] length = len(ids_values) * (1 + data_length) + 4 - return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write") + return cls.build( + scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE + ) class MockStatusPacket(MockFeetechPacket): @@ -222,19 +204,18 @@ class MockStatusPacket(MockFeetechPacket): """ @classmethod - def _build(cls, scs_id: int, params: list[int], length: int, error: str = "Success") -> list[int]: - err_byte = ERROR_TYPE[error] + def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]: return [ 0xFF, 0xFF, # header scs_id, # servo id length, # length - err_byte, # status + error, # status *params, # data bytes 0x00, # placeholder for checksum ] # fmt: skip @classmethod - def ping(cls, scs_id: int, error: str = "Success") -> bytes: + def ping(cls, scs_id: int, error: int = 0) -> bytes: """Builds a 'Ping' status packet. Args: @@ -247,7 +228,7 @@ class MockStatusPacket(MockFeetechPacket): return cls.build(scs_id, params=[], length=2, error=error) @classmethod - def read(cls, scs_id: int, value: int, param_length: int) -> bytes: + def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes: """Builds a 'Read' status packet. Args: @@ -260,7 +241,7 @@ class MockStatusPacket(MockFeetechPacket): """ params = _split_into_byte_chunks(value, param_length) length = param_length + 2 - return cls.build(scs_id, params=params, length=length) + return cls.build(scs_id, params=params, length=length, error=error) class MockPortHandler(scs.PortHandler): @@ -323,11 +304,11 @@ class MockMotors(MockSerial): ) return stub_name - def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0) -> str: + def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str: ping_request = MockInstructionPacket.ping(scs_id) - return_packet = MockStatusPacket.ping(scs_id) + return_packet = MockStatusPacket.ping(scs_id, error) ping_response = self._build_send_fn(return_packet, num_invalid_try) - stub_name = f"Ping_{scs_id}" + stub_name = f"Ping_{scs_id}_{error}" self.stub( name=stub_name, receive_bytes=ping_request, @@ -336,13 +317,19 @@ class MockMotors(MockSerial): return stub_name def build_read_stub( - self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0 + self, + address: int, + length: int, + scs_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, ) -> str: - address, length = self.ctrl_table[data_name] read_request = MockInstructionPacket.read(scs_id, address, length) - return_packet = MockStatusPacket.read(scs_id, value, length) + return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b"" read_response = self._build_send_fn(return_packet, num_invalid_try) - stub_name = f"Read_{data_name}_{scs_id}" + stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}" self.stub( name=stub_name, receive_bytes=read_request, @@ -350,15 +337,42 @@ class MockMotors(MockSerial): ) return stub_name - def build_sync_read_stub( - self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 + def build_write_stub( + self, + address: int, + length: int, + scs_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, ) -> str: - address, length = self.ctrl_table[data_name] - sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) - return_packets = b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) + sync_read_request = MockInstructionPacket.write(scs_id, value, address, length) + return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b"" + stub_name = f"Write_{address}_{length}_{scs_id}" + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=self._build_send_fn(return_packet, num_invalid_try), + ) + return stub_name + def build_sync_read_stub( + self, + address: int, + length: int, + ids_values: dict[int, int], + reply: bool = True, + num_invalid_try: int = 0, + ) -> str: + sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) + return_packets = ( + b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) + if reply + else b"" + ) sync_read_response = self._build_send_fn(return_packets, num_invalid_try) - stub_name = f"Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values]) + stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) self.stub( name=stub_name, receive_bytes=sync_read_request, @@ -367,11 +381,10 @@ class MockMotors(MockSerial): return stub_name def build_sequential_sync_read_stub( - self, data_name: str, ids_values: dict[int, list[int]] | None = None + self, address: int, length: int, ids_values: dict[int, list[int]] | None = None ) -> str: sequence_length = len(next(iter(ids_values.values()))) assert all(len(positions) == sequence_length for positions in ids_values.values()) - address, length = self.ctrl_table[data_name] sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) sequential_packets = [] for count in range(sequence_length): @@ -381,7 +394,7 @@ class MockMotors(MockSerial): sequential_packets.append(return_packets) sync_read_response = self._build_sequential_send_fn(sequential_packets) - stub_name = f"Seq_Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values]) + stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) self.stub( name=stub_name, receive_bytes=sync_read_request, @@ -390,11 +403,10 @@ class MockMotors(MockSerial): return stub_name def build_sync_write_stub( - self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 + self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0 ) -> str: - address, length = self.ctrl_table[data_name] sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length) - stub_name = f"Sync_Write_{data_name}_" + "_".join([str(id_) for id_ in ids_values]) + stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) self.stub( name=stub_name, receive_bytes=sync_read_request, @@ -402,20 +414,6 @@ class MockMotors(MockSerial): ) return stub_name - def build_write_stub( - self, data_name: str, scs_id: int, value: int, error: str = "Success", num_invalid_try: int = 0 - ) -> str: - address, length = self.ctrl_table[data_name] - sync_read_request = MockInstructionPacket.write(scs_id, value, address, length) - return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) - stub_name = f"Write_{data_name}_{scs_id}" - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=self._build_send_fn(return_packet, num_invalid_try), - ) - return stub_name - @staticmethod def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: def send_fn(_call_count: int) -> bytes: diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 2d3d4db7..d25b98bc 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -1,3 +1,4 @@ +import re import sys from typing import Generator from unittest.mock import MagicMock, patch @@ -6,7 +7,8 @@ import pytest import scservo_sdk as scs from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode -from lerobot.common.motors.feetech import MODEL_NUMBER_TABLE, FeetechMotorsBus +from lerobot.common.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus +from lerobot.common.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE from lerobot.common.utils.encoding_utils import encode_sign_magnitude from tests.mocks.mock_feetech import MockMotors, MockPortHandler @@ -109,8 +111,9 @@ def test_scan_port(mock_motors): @pytest.mark.parametrize("id_", [1, 2, 3]) def test_ping(id_, mock_motors, dummy_motors): expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] + addr, length = MODEL_NUMBER ping_stub = mock_motors.build_ping_stub(id_) - mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb) + mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, @@ -126,9 +129,15 @@ def test_ping(id_, mock_motors, dummy_motors): def test_broadcast_ping(mock_motors, dummy_motors): models = {m.id: m.model for m in dummy_motors.values()} - expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()} + addr, length = MODEL_NUMBER ping_stub = mock_motors.build_broadcast_ping_stub(list(models)) - mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs) + mobel_nb_stubs = [] + expected_model_nbs = {} + for id_, model in models.items(): + model_nb = MODEL_NUMBER_TABLE[model] + stub = mock_motors.build_read_stub(addr, length, id_, model_nb) + expected_model_nbs[id_] = model_nb + mobel_nb_stubs.append(stub) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, @@ -139,187 +148,209 @@ def test_broadcast_ping(mock_motors, dummy_motors): assert ping_model_nbs == expected_model_nbs assert mock_motors.stubs[ping_stub].called - assert mock_motors.stubs[mobel_nb_stub].called - - -def test_sync_read_none(mock_motors, dummy_motors): - expected_positions = { - "dummy_1": 1337, - "dummy_2": 42, - "dummy_3": 4016, - } - ids_values = dict(zip([1, 2, 3], expected_positions.values(), strict=True)) - stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - read_positions = motors_bus.sync_read("Present_Position", normalize=False) - - assert mock_motors.stubs[stub_name].called - assert read_positions == expected_positions + assert all(mock_motors.stubs[stub].called for stub in mobel_nb_stubs) @pytest.mark.parametrize( - "id_, position", + "addr, length, id_, value", [ - (1, 1337), - (2, 42), - (3, 4016), + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), ], ) -def test_sync_read_single_value(id_, position, mock_motors, dummy_motors): - expected_position = {f"dummy_{id_}": position} - stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position}) +def test__read(addr, length, id_, value, mock_motors, dummy_motors): + stub_name = mock_motors.build_read_stub(addr, length, id_, value) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect(assert_motors_exist=False) - read_position = motors_bus.sync_read("Present_Position", f"dummy_{id_}", normalize=False) + read_value, _, _ = motors_bus._read(addr, length, id_) assert mock_motors.stubs[stub_name].called - assert read_position == expected_position + assert read_value == value -@pytest.mark.parametrize( - "ids, positions", - [ - ([1], [1337]), - ([1, 2], [1337, 42]), - ([1, 2, 3], [1337, 42, 4016]), - ], - ids=["1 motor", "2 motors", "3 motors"], -) # fmt: skip -def test_sync_read(ids, positions, mock_motors, dummy_motors): - assert len(ids) == len(positions) - names = [f"dummy_{dxl_id}" for dxl_id in ids] - expected_positions = dict(zip(names, positions, strict=True)) - ids_values = dict(zip(ids, positions, strict=True)) - stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values) +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) + stub_name = mock_motors.build_read_stub(addr, length, id_, value, error=error) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, ) motors_bus.connect(assert_motors_exist=False) - read_positions = motors_bus.sync_read("Present_Position", names, normalize=False) - - assert mock_motors.stubs[stub_name].called - assert read_positions == expected_positions - - -@pytest.mark.parametrize( - "num_retry, num_invalid_try, pos", - [ - (0, 2, 1337), - (2, 3, 42), - (3, 2, 4016), - (2, 1, 999), - ], -) -def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors): - expected_position = {"dummy_1": pos} - stub_name = mock_motors.build_sync_read_stub( - "Present_Position", {1: pos}, num_invalid_try=num_invalid_try - ) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - if num_retry >= num_invalid_try: - pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry) - assert pos_dict == expected_position + if raise_on_error: + with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): + motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) else: - with pytest.raises(ConnectionError): - _ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry) - - expected_calls = min(1 + num_retry, 1 + num_invalid_try) - assert mock_motors.stubs[stub_name].calls == expected_calls - - -@pytest.mark.parametrize( - "data_name, value", - [ - ("Torque_Enable", 0), - ("Torque_Enable", 1), - ("Goal_Position", 1337), - ("Goal_Position", 42), - ], -) -def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors): - ids_values = {m.id: value for m in dummy_motors.values()} - stub_name = mock_motors.build_sync_write_stub(data_name, ids_values) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - motors_bus.sync_write(data_name, value, normalize=False) - - assert mock_motors.stubs[stub_name].wait_called() - - -@pytest.mark.parametrize( - "ids, positions", - [ - ([1], [1337]), - ([1, 2], [1337, 42]), - ([1, 2, 3], [1337, 42, 4016]), - ], - ids=["1 motor", "2 motors", "3 motors"], -) # fmt: skip -def test_sync_write(ids, positions, mock_motors, dummy_motors): - assert len(ids) == len(positions) - ids_values = dict(zip(ids, positions, strict=True)) - stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - write_values = {f"dummy_{id_}": pos for id_, pos in ids_values.items()} - motors_bus.sync_write("Goal_Position", write_values, normalize=False) - - assert mock_motors.stubs[stub_name].wait_called() - - -@pytest.mark.parametrize( - "data_name, dxl_id, value", - [ - ("Torque_Enable", 1, 0), - ("Torque_Enable", 1, 1), - ("Goal_Position", 2, 1337), - ("Goal_Position", 3, 42), - ], -) -def test_write(data_name, dxl_id, value, mock_motors, dummy_motors): - stub_name = mock_motors.build_write_stub(data_name, dxl_id, value) - motors_bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - motors_bus.connect(assert_motors_exist=False) - - motors_bus.write(data_name, f"dummy_{dxl_id}", value, normalize=False) + _, _, read_error = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_error == error assert mock_motors.stubs[stub_name].called +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub_name = mock_motors.build_read_stub(addr, length, id_, value, reply=False) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + else: + _, read_comm, _ = motors_bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_comm == scs.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub_name].called + + +@pytest.mark.parametrize( + "addr, length, id_, value", + [ + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), + ], +) +def test__write(addr, length, id_, value, mock_motors, dummy_motors): + stub_name = mock_motors.build_write_stub(addr, length, id_, value) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + comm, error = motors_bus._write(addr, length, id_, value) + + assert mock_motors.stubs[stub_name].called + assert comm == scs.COMM_SUCCESS + assert error == 0 + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) + stub_name = mock_motors.build_write_stub(addr, length, id_, value, error=error) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): + motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + _, write_error = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_error == error + + assert mock_motors.stubs[stub_name].called + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub_name = mock_motors.build_write_stub(addr, length, id_, value, reply=False) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + write_comm, _ = motors_bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_comm == scs.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub_name].called + + +@pytest.mark.parametrize( + "addr, length, ids_values", + [ + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): + stub_name = mock_motors.build_sync_read_stub(addr, length, ids_values) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + read_values, _ = motors_bus._sync_read(addr, length, list(ids_values)) + + assert mock_motors.stubs[stub_name].called + assert read_values == ids_values + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, ids_values = (10, 4, {1: 1337}) + stub_name = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + else: + _, read_comm = motors_bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + assert read_comm == scs.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub_name].called + + +@pytest.mark.parametrize( + "addr, length, ids_values", + [ + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): + stub_name = mock_motors.build_sync_write_stub(addr, length, ids_values) + motors_bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + motors_bus.connect(assert_motors_exist=False) + + comm = motors_bus._sync_write(addr, length, ids_values) + + assert mock_motors.stubs[stub_name].wait_called() + assert comm == scs.COMM_SUCCESS + + def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): encoded_homings = {m.id: encode_sign_magnitude(m.homing_offset, 11) for m in dummy_calibration.values()} mins = {m.id: m.range_min for m in dummy_calibration.values()} maxes = {m.id: m.range_max for m in dummy_calibration.values()} - offsets_stub = mock_motors.build_sync_read_stub("Homing_Offset", encoded_homings) - mins_stub = mock_motors.build_sync_read_stub("Min_Position_Limit", mins) - maxes_stub = mock_motors.build_sync_read_stub("Max_Position_Limit", maxes) + offsets_stub = mock_motors.build_sync_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings + ) + mins_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins) + maxes_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes) motors_bus = FeetechMotorsBus( port=mock_motors.port, motors=dummy_motors, @@ -340,9 +371,15 @@ def test_reset_calibration(mock_motors, dummy_motors): write_mins_stubs = [] write_maxes_stubs = [] for motor in dummy_motors.values(): - write_homing_stubs.append(mock_motors.build_write_stub("Homing_Offset", motor.id, 0)) - write_mins_stubs.append(mock_motors.build_write_stub("Min_Position_Limit", motor.id, 0)) - write_maxes_stubs.append(mock_motors.build_write_stub("Max_Position_Limit", motor.id, 4095)) + write_homing_stubs.append( + mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0) + ) + write_mins_stubs.append( + mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0) + ) + write_maxes_stubs.append( + mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095) + ) motors_bus = FeetechMotorsBus( port=mock_motors.port, @@ -372,11 +409,15 @@ def test_set_half_turn_homings(mock_motors, dummy_motors): 2: -2005, # 42 - 2047 3: 1625, # 3672 - 2047 } - read_pos_stub = mock_motors.build_sync_read_stub("Present_Position", current_positions) + read_pos_stub = mock_motors.build_sync_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions + ) write_homing_stubs = [] for id_, homing in expected_homings.items(): encoded_homing = encode_sign_magnitude(homing, 11) - stub = mock_motors.build_write_stub("Homing_Offset", id_, encoded_homing) + stub = mock_motors.build_write_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing + ) write_homing_stubs.append(stub) motors_bus = FeetechMotorsBus( @@ -409,7 +450,9 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors): "dummy_2": 3600, "dummy_3": 4002, } - read_pos_stub = mock_motors.build_sequential_sync_read_stub("Present_Position", positions) + read_pos_stub = mock_motors.build_sequential_sync_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions + ) with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]): motors_bus = FeetechMotorsBus( port=mock_motors.port, diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py index 7797622e..c98cda7d 100644 --- a/tests/motors/test_motors_bus.py +++ b/tests/motors/test_motors_bus.py @@ -1,9 +1,13 @@ +# ruff: noqa: N802 + import re +from unittest.mock import patch import pytest from lerobot.common.motors.motors_bus import ( Motor, + MotorNormMode, MotorsBus, assert_same_address, get_address, @@ -14,30 +18,35 @@ DUMMY_CTRL_TABLE_1 = { "Firmware_Version": (0, 1), "Model_Number": (1, 2), "Present_Position": (3, 4), - "Goal_Position": (7, 2), + "Goal_Position": (11, 2), } DUMMY_CTRL_TABLE_2 = { "Model_Number": (0, 2), "Firmware_Version": (2, 1), "Present_Position": (3, 4), - "Goal_Position": (7, 4), - "Lock": (7, 4), + "Present_Velocity": (7, 4), + "Goal_Position": (11, 4), + "Goal_Velocity": (15, 4), + "Lock": (19, 1), } DUMMY_MODEL_CTRL_TABLE = { "model_1": DUMMY_CTRL_TABLE_1, "model_2": DUMMY_CTRL_TABLE_2, + "model_3": DUMMY_CTRL_TABLE_2, } DUMMY_BAUDRATE_TABLE = { 0: 1_000_000, 1: 500_000, + 2: 250_000, } DUMMY_MODEL_BAUDRATE_TABLE = { "model_1": DUMMY_BAUDRATE_TABLE, "model_2": DUMMY_BAUDRATE_TABLE, + "model_3": DUMMY_BAUDRATE_TABLE, } DUMMY_ENCODING_TABLE = { @@ -48,21 +57,78 @@ DUMMY_ENCODING_TABLE = { DUMMY_MODEL_ENCODING_TABLE = { "model_1": DUMMY_ENCODING_TABLE, "model_2": DUMMY_ENCODING_TABLE, + "model_3": DUMMY_ENCODING_TABLE, +} + +DUMMY_MODEL_NUMBER_TABLE = { + "model_1": 1234, + "model_2": 5678, + "model_3": 5799, +} + +DUMMY_MODEL_RESOLUTION_TABLE = { + "model_1": 4096, + "model_2": 1024, + "model_3": 4096, } -class DummyMotorsBus(MotorsBus): +class MockPortHandler: + def __init__(self, port_name): + self.is_open: bool = False + self.baudrate: int + self.packet_start_time: float + self.packet_timeout: float + self.tx_time_per_byte: float + self.is_using: bool = False + self.port_name: str = port_name + self.ser = None + + def openPort(self): + self.is_open = True + return self.is_open + + def closePort(self): + self.is_open = False + + def clearPort(self): ... + def setPortName(self, port_name): + self.port_name = port_name + + def getPortName(self): + return self.port_name + + def setBaudRate(self, baudrate): + self.baudrate: baudrate + + def getBaudRate(self): + return self.baudrate + + def getBytesAvailable(self): ... + def readPort(self, length): ... + def writePort(self, packet): ... + def setPacketTimeout(self, packet_length): ... + def setPacketTimeoutMillis(self, msec): ... + def isPacketTimeout(self): ... + def getCurrentTime(self): ... + def getTimeSinceStart(self): ... + def setupPort(self, cflag_baud): ... + def getCFlagBaud(self, baudrate): ... + + +class MockMotorsBus(MotorsBus): available_baudrates = [500_000, 1_000_000] default_timeout = 1000 model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE model_ctrl_table = DUMMY_MODEL_CTRL_TABLE model_encoding_table = DUMMY_MODEL_ENCODING_TABLE - model_number_table = {"model_1": 1234, "model_2": 5678} - model_resolution_table = {"model_1": 4096, "model_2": 1024} + model_number_table = DUMMY_MODEL_NUMBER_TABLE + model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE normalized_data = ["Present_Position", "Goal_Position"] def __init__(self, port: str, motors: dict[str, Motor]): super().__init__(port, motors) + self.port_handler = MockPortHandler(port) def _assert_protocol_is_compatible(self, instruction_name): ... def configure_motors(self): ... @@ -75,6 +141,15 @@ class DummyMotorsBus(MotorsBus): def broadcast_ping(self, num_retry, raise_on_error): ... +@pytest.fixture +def dummy_motors() -> dict[str, Motor]: + return { + "dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100), + "dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100), + "dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100), + } + + def test_get_ctrl_table(): model = "model_1" ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model) @@ -105,7 +180,7 @@ def test_assert_same_address(): assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position") -def test_assert_same_address_different_addresses(): +def test_assert_same_length_different_addresses(): models = ["model_1", "model_2"] with pytest.raises( NotImplementedError, @@ -114,7 +189,7 @@ def test_assert_same_address_different_addresses(): assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number") -def test_assert_same_address_different_bytes(): +def test_assert_same_address_different_length(): models = ["model_1", "model_2"] with pytest.raises( NotImplementedError, @@ -124,18 +199,267 @@ def test_assert_same_address_different_bytes(): def test__serialize_data_invalid_length(): - bus = DummyMotorsBus("", {}) + bus = MockMotorsBus("", {}) with pytest.raises(NotImplementedError): bus._serialize_data(100, 3) def test__serialize_data_negative_numbers(): - bus = DummyMotorsBus("", {}) + bus = MockMotorsBus("", {}) with pytest.raises(ValueError): bus._serialize_data(-1, 1) def test__serialize_data_large_number(): - bus = DummyMotorsBus("", {}) + bus = MockMotorsBus("", {}) with pytest.raises(ValueError): bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF + + +@pytest.mark.parametrize( + "data_name, id_, value", + [ + ("Firmware_Version", 1, 14), + ("Model_Number", 1, 5678), + ("Present_Position", 2, 1337), + ("Present_Velocity", 3, 42), + ], +) +def test_read(data_name, id_, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + + with ( + patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read, + patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize, + ): + returned_value = bus.read(data_name, f"dummy_{id_}") + + assert returned_value == value + mock__read.assert_called_once_with( + addr, + length, + id_, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, {id_: value}) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(data_name, {id_: value}) + + +@pytest.mark.parametrize( + "data_name, id_, value", + [ + ("Goal_Position", 1, 1337), + ("Goal_Velocity", 2, 3682), + ("Lock", 3, 1), + ], +) +def test_write(data_name, id_, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + + with ( + patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write, + patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign, + patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize, + ): + bus.write(data_name, f"dummy_{id_}", value) + + mock__write.assert_called_once_with( + addr, + length, + id_, + value, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.", + ) + mock__encode_sign.assert_called_once_with(data_name, {id_: value}) + if data_name in bus.normalized_data: + mock__unnormalize.assert_called_once_with(data_name, {id_: value}) + + +@pytest.mark.parametrize( + "data_name, id_, value", + [ + ("Firmware_Version", 1, 14), + ("Model_Number", 1, 5678), + ("Present_Position", 2, 1337), + ("Present_Velocity", 3, 42), + ], +) +def test_sync_read_by_str(data_name, id_, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids = [id_] + expected_value = {f"dummy_{id_}": value} + + with ( + patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read, + patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize, + ): + returned_dict = bus.sync_read(data_name, f"dummy_{id_}") + + assert returned_dict == expected_value + mock__sync_read.assert_called_once_with( + addr, + length, + ids, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, {id_: value}) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(data_name, {id_: value}) + + +@pytest.mark.parametrize( + "data_name, ids_values", + [ + ("Model_Number", {1: 5678}), + ("Present_Position", {1: 1337, 2: 42}), + ("Present_Velocity", {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test_sync_read_by_list(data_name, ids_values, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids = list(ids_values) + expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} + + with ( + patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read, + patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize, + ): + returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids]) + + assert returned_dict == expected_values + mock__sync_read.assert_called_once_with( + addr, + length, + ids, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(data_name, ids_values) + + +@pytest.mark.parametrize( + "data_name, ids_values", + [ + ("Model_Number", {1: 5678, 2: 5799, 3: 5678}), + ("Present_Position", {1: 1337, 2: 42, 3: 4016}), + ("Goal_Position", {1: 4008, 2: 199, 3: 3446}), + ], + ids=["Model_Number", "Present_Position", "Goal_Position"], +) +def test_sync_read_by_none(data_name, ids_values, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids = list(ids_values) + expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} + + with ( + patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read, + patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize, + ): + returned_dict = bus.sync_read(data_name) + + assert returned_dict == expected_values + mock__sync_read.assert_called_once_with( + addr, + length, + ids, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(data_name, ids_values) + + +@pytest.mark.parametrize( + "data_name, value", + [ + ("Goal_Position", 500), + ("Goal_Velocity", 4010), + ("Lock", 0), + ], +) +def test_sync_write_by_single_value(data_name, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids_values = {m.id: value for m in dummy_motors.values()} + + with ( + patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write, + patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign, + patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize, + ): + bus.sync_write(data_name, value) + + mock__sync_write.assert_called_once_with( + addr, + length, + ids_values, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.", + ) + mock__encode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__unnormalize.assert_called_once_with(data_name, ids_values) + + +@pytest.mark.parametrize( + "data_name, ids_values", + [ + ("Goal_Position", {1: 1337, 2: 42, 3: 4016}), + ("Goal_Velocity", {1: 50, 2: 83, 3: 2777}), + ("Lock", {1: 0, 2: 0, 3: 1}), + ], + ids=["Goal_Position", "Goal_Velocity", "Lock"], +) +def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(assert_motors_exist=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + values = {f"dummy_{id_}": val for id_, val in ids_values.items()} + + with ( + patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write, + patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign, + patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize, + ): + bus.sync_write(data_name, values) + + mock__sync_write.assert_called_once_with( + addr, + length, + ids_values, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.", + ) + mock__encode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__unnormalize.assert_called_once_with(data_name, ids_values)