From 5a57e6f4a733cdccf64c2aa1da9effa2ddc1f7b4 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 23 Mar 2025 13:25:45 +0100 Subject: [PATCH] Rename read/write -> sync_read/write, refactor, add write --- lerobot/common/motors/dynamixel/dynamixel.py | 7 +- lerobot/common/motors/feetech/feetech.py | 7 +- lerobot/common/motors/motors_bus.py | 152 +++++++++++++------ tests/motors/test_dynamixel.py | 14 +- tests/motors/test_feetech.py | 14 +- 5 files changed, 126 insertions(+), 68 deletions(-) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index 7e6232e8..b0c9844f 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -55,8 +55,8 @@ class DynamixelMotorsBus(MotorsBus): self.port_handler = dxl.PortHandler(self.port) self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION) - self.reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) - self.writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0) + self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) + self.sync_writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0) def broadcast_ping( self, num_retry: int = 0, raise_on_error: bool = False @@ -82,6 +82,9 @@ class DynamixelMotorsBus(MotorsBus): return comm == dxl.COMM_SUCCESS + def _is_error(self, error: int) -> bool: + return error != 0x00 + @staticmethod def split_int_bytes(value: int, n_bytes: int) -> list[int]: # Validate input diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 23ea9884..5d1fa69e 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -51,8 +51,8 @@ class FeetechMotorsBus(MotorsBus): self.port_handler = scs.PortHandler(self.port) self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION) - self.reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) - self.writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0) + self.sync_reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) + self.sync_writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0) def broadcast_ping(self, num_retry: int | None = None): raise NotImplementedError # TODO @@ -70,6 +70,9 @@ class FeetechMotorsBus(MotorsBus): return comm == scs.COMM_SUCCESS + def _is_error(self, error: int) -> bool: + return error != 0x00 + @staticmethod def split_int_bytes(value: int, n_bytes: int) -> list[int]: # Validate input diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 6da8b36d..4558d56c 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -256,8 +256,8 @@ class MotorsBus(abc.ABC): self.port_handler: PortHandler self.packet_handler: PacketHandler - self.reader: GroupSyncRead - self.writer: GroupSyncWrite + self.sync_reader: GroupSyncRead + self.sync_writer: GroupSyncWrite self.calibration = None @@ -347,7 +347,7 @@ class MotorsBus(abc.ABC): """ try: # TODO(aliberts): use ping instead - return (self.ids == self.read("ID")).all() + return (self.ids == self.sync_read("ID")).all() except ConnectionError as e: logger.error(e) return False @@ -395,6 +395,10 @@ class MotorsBus(abc.ABC): def _is_comm_success(self, comm: int) -> bool: pass + @abc.abstractmethod + def _is_error(self, error: int) -> bool: + pass + @staticmethod @abc.abstractmethod def split_int_bytes(value: int, n_bytes: int) -> list[int]: @@ -442,12 +446,12 @@ class MotorsBus(abc.ABC): raise TypeError(f"'{motor}' should be int, str.") @overload - def read(self, data_name: str, motors: None = ..., num_retry: int = ...) -> dict[str, Value]: ... + def sync_read(self, data_name: str, motors: None = ..., num_retry: int = ...) -> dict[str, Value]: ... @overload - def read( + def sync_read( self, data_name: str, motors: NameOrID | list[NameOrID], num_retry: int = ... ) -> dict[NameOrID, Value]: ... - def read( + def sync_read( self, data_name: str, motors: NameOrID | list[NameOrID] | None = None, num_retry: int = 0 ) -> dict[NameOrID, Value]: if not self.is_connected: @@ -466,17 +470,11 @@ class MotorsBus(abc.ABC): raise TypeError(motors) motor_ids = list(id_key_map) - if self._has_different_ctrl_tables: - models = [self.id_to_model(idx) for idx in motor_ids] - assert_same_address(self.model_ctrl_table, models, data_name) - model = self.id_to_model(next(iter(motor_ids))) - addr, n_bytes = self.model_ctrl_table[model][data_name] - - comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry) + comm, ids_values = self._sync_read(data_name, motor_ids, num_retry) if not self._is_comm_success(comm): raise ConnectionError( - f"Failed to read {data_name} on port {self.port} for ids {motor_ids}:" + f"Failed to sync read '{data_name}' on {motor_ids=} after {num_retry + 1} tries." f"{self.packet_handler.getTxRxResult(comm)}" ) @@ -485,40 +483,50 @@ class MotorsBus(abc.ABC): return {id_key_map[idx]: val for idx, val in ids_values.items()} - def _read( - self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 0 + def _sync_read( + self, data_name: str, motor_ids: list[str], num_retry: int = 0 ) -> tuple[int, dict[int, int]]: - self.reader.clearParam() - self.reader.start_address = address - self.reader.data_length = n_bytes + if self._has_different_ctrl_tables: + models = [self.id_to_model(idx) for idx in motor_ids] + assert_same_address(self.model_ctrl_table, models, data_name) + + model = self.id_to_model(next(iter(motor_ids))) + addr, n_bytes = self.model_ctrl_table[model][data_name] + self._setup_sync_reader(motor_ids, addr, n_bytes) # FIXME(aliberts, pkooij): We should probably not have to do this. # Let's try to see if we can do with better comm status handling instead. # self.port_handler.ser.reset_output_buffer() # self.port_handler.ser.reset_input_buffer() - for idx in motor_ids: - self.reader.addParam(idx) - for n_try in range(1 + num_retry): - comm = self.reader.txRxPacket() + comm = self.sync_reader.txRxPacket() if self._is_comm_success(comm): break - logger.debug(f"ids={list(motor_ids)} @{address} ({n_bytes} bytes) {n_try=} got {comm=}") + logger.debug(f"Failed to sync read '{data_name}' ({addr=} {n_bytes=}) on {motor_ids=} ({n_try=})") + logger.debug(self.packet_handler.getRxPacketError(comm)) - values = {idx: self.reader.getData(idx, address, n_bytes) for idx in motor_ids} + values = {idx: self.sync_reader.getData(idx, addr, n_bytes) for idx in motor_ids} return comm, values - # TODO(aliberts, pkooij): Implementing something like this could get much faster read times. - # Note: this could be at the cost of increase latency between the moment the data is produced by the - # motors and the moment it is used by a policy + def _setup_sync_reader(self, motor_ids: list[str], addr: int, n_bytes: int) -> None: + self.sync_reader.clearParam() + self.sync_reader.start_address = addr + self.sync_reader.data_length = n_bytes + for idx in motor_ids: + self.sync_reader.addParam(idx) + + # TODO(aliberts, pkooij): Implementing something like this could get even much faster read times if need be. + # Would have to handle the logic of checking if a packet has been sent previously though but doable. + # This could be at the cost of increase latency between the moment the data is produced by the motors and + # the moment it is used by a policy. # def _async_read(self, motor_ids: list[str], address: int, n_bytes: int): # self.reader.rxPacket() # self.reader.txPacket() # for idx in motor_ids: # value = self.reader.getData(idx, address, n_bytes) - def write(self, data_name: str, values: Value | dict[NameOrID, Value], num_retry: int = 0) -> None: + def sync_write(self, data_name: str, values: Value | dict[NameOrID, Value], num_retry: int = 0) -> None: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." @@ -531,40 +539,84 @@ class MotorsBus(abc.ABC): else: raise ValueError(f"'values' is expected to be a single value or a dict. Got {values}") + if data_name in self.calibration_required and self.calibration is not None: + ids_values = self.uncalibrate_values(ids_values) + + comm = self._sync_write(data_name, ids_values, num_retry) + if not self._is_comm_success(comm): + raise ConnectionError( + f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." + f"\n{self.packet_handler.getTxRxResult(comm)}" + ) + + def _sync_write(self, data_name: str, ids_values: dict[int, int], num_retry: int = 0) -> int: if self._has_different_ctrl_tables: models = [self.id_to_model(idx) for idx in ids_values] assert_same_address(self.model_ctrl_table, models, data_name) - if data_name in self.calibration_required and self.calibration is not None: - ids_values = self.uncalibrate_values(ids_values) - model = self.id_to_model(next(iter(ids_values))) addr, n_bytes = self.model_ctrl_table[model][data_name] - - comm = self._write(ids_values, addr, n_bytes, num_retry) - if not self._is_comm_success(comm): - raise ConnectionError( - f"Failed to write {data_name} on port {self.port} for ids {list(ids_values)}:" - f"{self.packet_handler.getTxRxResult(comm)}" - ) - - def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 0) -> int: - self.writer.clearParam() - self.writer.start_address = address - self.writer.data_length = n_bytes - - for idx, value in ids_values.items(): - data = self.split_int_bytes(value, n_bytes) - self.writer.addParam(idx, data) + self._setup_sync_writer(ids_values, addr, n_bytes) for n_try in range(1 + num_retry): - comm = self.writer.txPacket() + comm = self.sync_writer.txPacket() if self._is_comm_success(comm): break - logger.debug(f"ids={list(ids_values)} @{address} ({n_bytes} bytes) {n_try=} got {comm=}") + logger.debug( + f"Failed to sync write '{data_name}' ({addr=} {n_bytes=}) with {ids_values=} ({n_try=})" + ) + logger.debug(self.packet_handler.getRxPacketError(comm)) return comm + def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, n_bytes: int) -> None: + self.sync_writer.clearParam() + self.sync_writer.start_address = addr + self.sync_writer.data_length = n_bytes + for idx, value in ids_values.items(): + data = self.split_int_bytes(value, n_bytes) + self.sync_writer.addParam(idx, data) + + def write(self, data_name: str, motor: NameOrID, value: Value, num_retry: int = 0) -> None: + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." + ) + + idx = self.get_motor_id(motor) + + if data_name in self.calibration_required and self.calibration is not None: + id_value = self.uncalibrate_values({idx: value}) + value = id_value[idx] + + comm, error = self._write(data_name, idx, value, num_retry) + if not self._is_comm_success(comm): + raise ConnectionError( + f"Failed to write '{data_name}' on {idx=} with '{value}' after {num_retry + 1} tries." + f"\n{self.packet_handler.getTxRxResult(comm)}" + ) + elif self._is_error(error): + raise RuntimeError( + f"Failed to write '{data_name}' on {idx=} with '{value}' after {num_retry + 1} tries." + f"\n{self.packet_handler.getRxPacketError(error)}" + ) + + def _write(self, data_name: str, motor_id: int, value: int, num_retry: int = 0) -> tuple[int, int]: + model = self.id_to_model(motor_id) + addr, n_bytes = self.model_ctrl_table[model][data_name] + data = self.split_int_bytes(value, n_bytes) + + for n_try in range(1 + num_retry): + comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) + if self._is_comm_success(comm): + break + logger.debug( + f"Failed to write '{data_name}' ({addr=} {n_bytes=}) on {motor_id=} with '{value}' ({n_try=})" + ) + logger.debug(self.packet_handler.getRxPacketError(comm)) + + return comm, error + def disconnect(self) -> None: if not self.is_connected: raise DeviceNotConnectedError( diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index 688a7367..c189e243 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -156,7 +156,7 @@ def test_read_all_motors(motors, mock_motors, dummy_motors): ) motors_bus.connect() - positions_read = motors_bus.read("Present_Position", motors=motors) + positions_read = motors_bus.sync_read("Present_Position", motors=motors) motors = ["dummy_1", "dummy_2", "dummy_3"] if motors is None else motors assert mock_motors.stubs[stub_name].called @@ -180,7 +180,7 @@ def test_read_single_motor_by_name(idx, pos, mock_motors, dummy_motors): ) motors_bus.connect() - pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}") + pos_dict = motors_bus.sync_read("Present_Position", f"dummy_{idx}") assert mock_motors.stubs[stub_name].called assert pos_dict == {f"dummy_{idx}": pos} @@ -203,7 +203,7 @@ def test_read_single_motor_by_id(idx, pos, mock_motors, dummy_motors): ) motors_bus.connect() - pos_dict = motors_bus.read("Present_Position", idx) + pos_dict = motors_bus.sync_read("Present_Position", idx) assert mock_motors.stubs[stub_name].called assert pos_dict == {idx: pos} @@ -230,11 +230,11 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_moto motors_bus.connect() if num_retry >= num_invalid_try: - pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry) + pos_dict = motors_bus.sync_read("Present_Position", 1, num_retry=num_retry) assert pos_dict == {1: pos} else: with pytest.raises(ConnectionError): - _ = motors_bus.read("Present_Position", 1, num_retry=num_retry) + _ = motors_bus.sync_read("Present_Position", 1, num_retry=num_retry) expected_calls = min(1 + num_retry, 1 + num_invalid_try) assert mock_motors.stubs[stub_name].calls == expected_calls @@ -263,7 +263,7 @@ def test_write_all_motors(motors, mock_motors, dummy_motors): motors_bus.connect() values = dict(zip(motors, goal_positions.values(), strict=True)) - motors_bus.write("Goal_Position", values) + motors_bus.sync_write("Goal_Position", values) assert mock_motors.stubs[stub_name].wait_called() @@ -284,6 +284,6 @@ def test_write_all_motors_single_value(data_name, value, mock_motors, dummy_moto ) motors_bus.connect() - motors_bus.write(data_name, value) + motors_bus.sync_write(data_name, value) assert mock_motors.stubs[stub_name].wait_called() diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index cb9dcda8..ac80f221 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -158,7 +158,7 @@ def test_read_all_motors(motors, mock_motors, dummy_motors): ) motors_bus.connect() - positions_read = motors_bus.read("Present_Position", motors=motors) + positions_read = motors_bus.sync_read("Present_Position", motors=motors) motors = ["dummy_1", "dummy_2", "dummy_3"] if motors is None else motors assert mock_motors.stubs[stub_name].called @@ -182,7 +182,7 @@ def test_read_single_motor_by_name(idx, pos, mock_motors, dummy_motors): ) motors_bus.connect() - pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}") + pos_dict = motors_bus.sync_read("Present_Position", f"dummy_{idx}") assert mock_motors.stubs[stub_name].called assert pos_dict == {f"dummy_{idx}": pos} @@ -205,7 +205,7 @@ def test_read_single_motor_by_id(idx, pos, mock_motors, dummy_motors): ) motors_bus.connect() - pos_dict = motors_bus.read("Present_Position", idx) + pos_dict = motors_bus.sync_read("Present_Position", idx) assert mock_motors.stubs[stub_name].called assert pos_dict == {idx: pos} @@ -232,11 +232,11 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_moto motors_bus.connect() if num_retry >= num_invalid_try: - pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry) + pos_dict = motors_bus.sync_read("Present_Position", 1, num_retry=num_retry) assert pos_dict == {1: pos} else: with pytest.raises(ConnectionError): - _ = motors_bus.read("Present_Position", 1, num_retry=num_retry) + _ = motors_bus.sync_read("Present_Position", 1, num_retry=num_retry) expected_calls = min(1 + num_retry, 1 + num_invalid_try) assert mock_motors.stubs[stub_name].calls == expected_calls @@ -265,7 +265,7 @@ def test_write_all_motors(motors, mock_motors, dummy_motors): motors_bus.connect() values = dict(zip(motors, goal_positions.values(), strict=True)) - motors_bus.write("Goal_Position", values) + motors_bus.sync_write("Goal_Position", values) assert mock_motors.stubs[stub_name].wait_called() @@ -286,6 +286,6 @@ def test_write_all_motors_single_value(data_name, value, mock_motors, dummy_moto ) motors_bus.connect() - motors_bus.write(data_name, value) + motors_bus.sync_write(data_name, value) assert mock_motors.stubs[stub_name].wait_called()