diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index a9e00247..3a8a80be 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -89,7 +89,7 @@ class DynamixelMotorsBus(MotorsBus): self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) self.sync_writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0) self._comm_success = dxl.COMM_SUCCESS - self._error = 0x00 + self._no_error = 0x00 def broadcast_ping( self, num_retry: int = 0, raise_on_error: bool = False @@ -102,16 +102,16 @@ class DynamixelMotorsBus(MotorsBus): if raise_on_error: raise ConnectionError(f"Broadcast ping returned a {comm} comm error.") - def calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]: + def _calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]: # TODO return ids_values - def uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]: + def _uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]: # TODO return ids_values @staticmethod - def split_int_bytes(value: int, n_bytes: int) -> list[int]: + def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]: # Validate input if value < 0: raise ValueError(f"Negative values are not allowed: {value}") diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 14aeb3df..98144473 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -69,21 +69,24 @@ class FeetechMotorsBus(MotorsBus): self.sync_reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) self.sync_writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0) self._comm_success = scs.COMM_SUCCESS - self._error = 0x00 + self._no_error = 0x00 - def broadcast_ping(self, num_retry: int | None = None): - raise NotImplementedError # TODO + def broadcast_ping( + self, num_retry: int = 0, raise_on_error: bool = False + ) -> dict[int, list[int, int]] | None: + # TODO + raise NotImplementedError - def calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]: + def _calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]: # TODO return ids_values - def uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]: + def _uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]: # TODO return ids_values @staticmethod - def split_int_bytes(value: int, n_bytes: int) -> list[int]: + def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]: # Validate input if value < 0: raise ValueError(f"Negative values are not allowed: {value}") diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 249c66a8..0694b66a 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -259,7 +259,7 @@ class MotorsBus(abc.ABC): self.sync_reader: GroupSyncRead self.sync_writer: GroupSyncWrite self._comm_success: int - self._error: int + self._no_error: int self.calibration = None @@ -285,12 +285,6 @@ class MotorsBus(abc.ABC): first_table = self.model_ctrl_table[self.models[0]] return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.models[1:]) - def id_to_model(self, motor_id: int) -> str: - return self._id_to_model[motor_id] - - def id_to_name(self, motor_id: int) -> str: - return self._id_to_name[motor_id] - @cached_property def names(self) -> list[str]: return list(self.motors) @@ -303,6 +297,20 @@ class MotorsBus(abc.ABC): def ids(self) -> list[int]: return [m.id for m in self.motors.values()] + def _id_to_model(self, motor_id: int) -> str: + return self._id_to_model[motor_id] + + def _id_to_name(self, motor_id: int) -> str: + return self._id_to_name[motor_id] + + def _get_motor_id(self, motor: NameOrID) -> int: + if isinstance(motor, str): + return self.motors[motor].id + elif isinstance(motor, int): + return motor + else: + raise TypeError(f"'{motor}' should be int, str.") + def _validate_motors(self) -> None: # TODO(aliberts): improve error messages for this (display problematics values) if len(self.ids) != len(set(self.ids)): @@ -314,6 +322,12 @@ class MotorsBus(abc.ABC): if any(model not in self.model_resolution_table for model in self.models): raise ValueError("Some motors models are not available.") + def _is_comm_success(self, comm: int) -> bool: + return comm == self._comm_success + + def _is_error(self, error: int) -> bool: + return error != self._no_error + @property def is_connected(self) -> bool: return self.port_handler.is_open @@ -341,6 +355,18 @@ class MotorsBus(abc.ABC): timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout self.port_handler.setPacketTimeoutMillis(timeout_ms) + def get_baudrate(self) -> int: + return self.port_handler.getBaudRate() + + def set_baudrate(self, baudrate: int) -> None: + present_bus_baudrate = self.port_handler.getBaudRate() + if present_bus_baudrate != baudrate: + logger.info(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") + self.port_handler.setBaudRate(baudrate) + + if self.port_handler.getBaudRate() != baudrate: + raise OSError("Failed to write bus baud rate.") + @property def are_motors_configured(self) -> bool: """ @@ -355,7 +381,7 @@ class MotorsBus(abc.ABC): return False def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None: - idx = self.get_motor_id(motor) + idx = self._get_motor_id(motor) for n_try in range(1 + num_retry): model_number, comm, error = self.packet_handler.ping(self.port_handler, idx) if self._is_comm_success(comm): @@ -368,16 +394,8 @@ class MotorsBus(abc.ABC): @abc.abstractmethod def broadcast_ping( self, num_retry: int = 0, raise_on_error: bool = False - ) -> dict[int, list[int, int]] | None: ... - - def set_baudrate(self, baudrate) -> None: - present_bus_baudrate = self.port_handler.getBaudRate() - if present_bus_baudrate != baudrate: - logger.info(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") - self.port_handler.setBaudRate(baudrate) - - if self.port_handler.getBaudRate() != baudrate: - raise OSError("Failed to write bus baud rate.") + ) -> dict[int, list[int, int]] | None: + pass def set_calibration(self, calibration_fpath: Path) -> None: with open(calibration_fpath) as f: @@ -386,22 +404,16 @@ class MotorsBus(abc.ABC): self.calibration = {int(idx): val for idx, val in calibration.items()} @abc.abstractmethod - def calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]: + def _calibrate_values(self, ids_values: dict[int, int]) -> dict[int, float]: pass @abc.abstractmethod - def uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]: + def _uncalibrate_values(self, ids_values: dict[int, float]) -> dict[int, int]: pass - def _is_comm_success(self, comm: int) -> bool: - return comm == self._comm_success - - def _is_error(self, error: int) -> bool: - return error != self._error - @staticmethod @abc.abstractmethod - def split_int_bytes(value: int, n_bytes: int) -> list[int]: + def _split_int_to_bytes(value: int, n_bytes: int) -> list[int]: """ Splits an unsigned integer into a list of bytes in little-endian order. @@ -437,14 +449,6 @@ class MotorsBus(abc.ABC): """ pass - def get_motor_id(self, motor: NameOrID) -> int: - if isinstance(motor, str): - return self.motors[motor].id - elif isinstance(motor, int): - return motor - else: - raise TypeError(f"'{motor}' should be int, str.") - @overload def sync_read(self, data_name: str, motors: None = ..., num_retry: int = ...) -> dict[str, Value]: ... @overload @@ -463,9 +467,9 @@ class MotorsBus(abc.ABC): if motors is None: id_key_map = {m.id: name for name, m in self.motors.items()} elif isinstance(motors, (str, int)): - id_key_map = {self.get_motor_id(motors): motors} + id_key_map = {self._get_motor_id(motors): motors} elif isinstance(motors, list): - id_key_map = {self.get_motor_id(m): m for m in motors} + id_key_map = {self._get_motor_id(m): m for m in motors} else: raise TypeError(motors) @@ -479,7 +483,7 @@ class MotorsBus(abc.ABC): ) if data_name in self.calibration_required and self.calibration is not None: - ids_values = self.calibrate_values(ids_values) + ids_values = self._calibrate_values(ids_values) return {id_key_map[idx]: val for idx, val in ids_values.items()} @@ -487,10 +491,10 @@ class MotorsBus(abc.ABC): self, data_name: str, motor_ids: list[str], num_retry: int = 0 ) -> tuple[int, dict[int, int]]: if self._has_different_ctrl_tables: - models = [self.id_to_model(idx) for idx in motor_ids] + models = [self._id_to_model(idx) for idx in motor_ids] assert_same_address(self.model_ctrl_table, models, data_name) - model = self.id_to_model(next(iter(motor_ids))) + model = self._id_to_model(next(iter(motor_ids))) addr, n_bytes = self.model_ctrl_table[model][data_name] self._setup_sync_reader(motor_ids, addr, n_bytes) @@ -535,12 +539,12 @@ class MotorsBus(abc.ABC): if isinstance(values, int): ids_values = {id_: values for id_ in self.ids} elif isinstance(values, dict): - ids_values = {self.get_motor_id(motor): val for motor, val in values.items()} + ids_values = {self._get_motor_id(motor): val for motor, val in values.items()} else: raise ValueError(f"'values' is expected to be a single value or a dict. Got {values}") if data_name in self.calibration_required and self.calibration is not None: - ids_values = self.uncalibrate_values(ids_values) + ids_values = self._uncalibrate_values(ids_values) comm = self._sync_write(data_name, ids_values, num_retry) if not self._is_comm_success(comm): @@ -551,10 +555,10 @@ class MotorsBus(abc.ABC): def _sync_write(self, data_name: str, ids_values: dict[int, int], num_retry: int = 0) -> int: if self._has_different_ctrl_tables: - models = [self.id_to_model(idx) for idx in ids_values] + models = [self._id_to_model(idx) for idx in ids_values] assert_same_address(self.model_ctrl_table, models, data_name) - model = self.id_to_model(next(iter(ids_values))) + model = self._id_to_model(next(iter(ids_values))) addr, n_bytes = self.model_ctrl_table[model][data_name] self._setup_sync_writer(ids_values, addr, n_bytes) @@ -574,7 +578,7 @@ class MotorsBus(abc.ABC): self.sync_writer.start_address = addr self.sync_writer.data_length = n_bytes for idx, value in ids_values.items(): - data = self.split_int_bytes(value, n_bytes) + data = self._split_int_to_bytes(value, n_bytes) self.sync_writer.addParam(idx, data) def write(self, data_name: str, motor: NameOrID, value: Value, num_retry: int = 0) -> None: @@ -583,10 +587,10 @@ class MotorsBus(abc.ABC): f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." ) - idx = self.get_motor_id(motor) + idx = self._get_motor_id(motor) if data_name in self.calibration_required and self.calibration is not None: - id_value = self.uncalibrate_values({idx: value}) + id_value = self._uncalibrate_values({idx: value}) value = id_value[idx] comm, error = self._write(data_name, idx, value, num_retry) @@ -602,9 +606,9 @@ class MotorsBus(abc.ABC): ) def _write(self, data_name: str, motor_id: int, value: int, num_retry: int = 0) -> tuple[int, int]: - model = self.id_to_model(motor_id) + model = self._id_to_model(motor_id) addr, n_bytes = self.model_ctrl_table[model][data_name] - data = self.split_int_bytes(value, n_bytes) + data = self._split_int_to_bytes(value, n_bytes) for n_try in range(1 + num_retry): comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data) diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 2f12283c..a7bcf30d 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -282,7 +282,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): """ data = [] for idx, value in ids_values.items(): - split_value = DynamixelMotorsBus.split_int_bytes(value, data_length) + split_value = DynamixelMotorsBus._split_int_to_bytes(value, data_length) data += [idx, *split_value] params = [ dxl.DXL_LOBYTE(start_address), @@ -319,7 +319,7 @@ class MockInstructionPacket(MockDynamixelPacketv2): +2 is for the length bytes, +2 is for the CRC at the end. """ - data = DynamixelMotorsBus.split_int_bytes(value, data_length) + data = DynamixelMotorsBus._split_int_to_bytes(value, data_length) params = [ dxl.DXL_LOBYTE(start_address), dxl.DXL_HIBYTE(start_address), diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 8686e5af..b1478a42 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -149,7 +149,7 @@ class MockInstructionPacket(MockFeetechPacket): """ data = [] for idx, value in ids_values.items(): - split_value = FeetechMotorsBus.split_int_bytes(value, data_length) + split_value = FeetechMotorsBus._split_int_to_bytes(value, data_length) data += [idx, *split_value] params = [start_address, data_length, *data] length = len(ids_values) * (1 + data_length) + 4 @@ -179,7 +179,7 @@ class MockInstructionPacket(MockFeetechPacket): +1 is for the length bytes, +1 is for the checksum at the end. """ - data = FeetechMotorsBus.split_int_bytes(value, data_length) + data = FeetechMotorsBus._split_int_to_bytes(value, data_length) params = [start_address, *data] length = data_length + 3 return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write") diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index 278c24d2..7fdd2618 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -67,24 +67,24 @@ def test_autouse_patch(): "max four bytes", ], ) # fmt: skip -def test_split_int_bytes(value, n_bytes, expected): - assert DynamixelMotorsBus.split_int_bytes(value, n_bytes) == expected +def test_split_int_to_bytes(value, n_bytes, expected): + assert DynamixelMotorsBus._split_int_to_bytes(value, n_bytes) == expected -def test_split_int_bytes_invalid_n_bytes(): +def test_split_int_to_bytes_invalid_n_bytes(): with pytest.raises(NotImplementedError): - DynamixelMotorsBus.split_int_bytes(100, 3) + DynamixelMotorsBus._split_int_to_bytes(100, 3) -def test_split_int_bytes_negative_numbers(): +def test_split_int_to_bytes_negative_numbers(): with pytest.raises(ValueError): - neg = DynamixelMotorsBus.split_int_bytes(-1, 1) + neg = DynamixelMotorsBus._split_int_to_bytes(-1, 1) print(neg) -def test_split_int_bytes_large_number(): +def test_split_int_to_bytes_large_number(): with pytest.raises(ValueError): - DynamixelMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF + DynamixelMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF def test_abc_implementation(dummy_motors): diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 45ffd575..1b50d045 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -67,24 +67,24 @@ def test_autouse_patch(): "max four bytes", ], ) # fmt: skip -def test_split_int_bytes(value, n_bytes, expected): - assert FeetechMotorsBus.split_int_bytes(value, n_bytes) == expected +def test_split_int_to_bytes(value, n_bytes, expected): + assert FeetechMotorsBus._split_int_to_bytes(value, n_bytes) == expected -def test_split_int_bytes_invalid_n_bytes(): +def test_split_int_to_bytes_invalid_n_bytes(): with pytest.raises(NotImplementedError): - FeetechMotorsBus.split_int_bytes(100, 3) + FeetechMotorsBus._split_int_to_bytes(100, 3) -def test_split_int_bytes_negative_numbers(): +def test_split_int_to_bytes_negative_numbers(): with pytest.raises(ValueError): - neg = FeetechMotorsBus.split_int_bytes(-1, 1) + neg = FeetechMotorsBus._split_int_to_bytes(-1, 1) print(neg) -def test_split_int_bytes_large_number(): +def test_split_int_to_bytes_large_number(): with pytest.raises(ValueError): - FeetechMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF + FeetechMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF def test_abc_implementation(dummy_motors):