diff --git a/lerobot/common/motors/__init__.py b/lerobot/common/motors/__init__.py index a746ba96..3261ae7e 100644 --- a/lerobot/common/motors/__init__.py +++ b/lerobot/common/motors/__init__.py @@ -1,3 +1 @@ -from .motors_bus import CalibrationMode, DriveMode, MotorsBus, TorqueMode - -__all__ = ["CalibrationMode", "DriveMode", "MotorsBus", "TorqueMode"] +from .motors_bus import CalibrationMode, DriveMode, Motor, MotorsBus, TorqueMode diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 24683730..0ae7d223 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -22,11 +22,12 @@ import abc import json import time +from dataclasses import dataclass from enum import Enum from functools import cached_property from pathlib import Path from pprint import pformat -from typing import Protocol +from typing import Protocol, TypeAlias, Union import serial import tqdm @@ -35,6 +36,8 @@ from deepdiff import DeepDiff from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.common.utils.utils import capture_timestamp_utc +MotorLike: TypeAlias = Union[str, int, "Motor"] + MAX_ID_RANGE = 252 @@ -197,6 +200,12 @@ class GroupSyncWrite(Protocol): def txPacket(self): ... +@dataclass +class Motor: + id: int + model: str + + class MotorsBus(abc.ABC): """The main LeRobot class for implementing motors buses. @@ -248,7 +257,7 @@ class MotorsBus(abc.ABC): def __init__( self, port: str, - motors: dict[str, tuple[int, str]], + motors: dict[str, Motor], ): self.port = port self.motors = motors @@ -262,8 +271,8 @@ class MotorsBus(abc.ABC): self.logs = {} # TODO(aliberts): use subclass logger self.calibration = None - self._id_to_model = dict(self.motors.values()) - self._id_to_name = {idx: name for name, (idx, _) in self.motors.items()} + self._id_to_model = {m.id: m.model for m in self.motors.values()} + self._id_to_name = {m.id: name for name, m in self.motors.items()} def __len__(self): return len(self.motors) @@ -278,39 +287,39 @@ class MotorsBus(abc.ABC): @cached_property def _has_different_ctrl_tables(self) -> bool: - if len(self.motor_models) < 2: + if len(self.models) < 2: return False - first_table = self.model_ctrl_table[self.motor_models[0]] - return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.motor_models[1:]) + first_table = self.model_ctrl_table[self.models[0]] + return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.models[1:]) - def idx_to_model(self, idx: int) -> str: - return self._id_to_model[idx] + def id_to_model(self, motor_id: int) -> str: + return self._id_to_model[motor_id] - def idx_to_name(self, idx: int) -> str: - return self._id_to_name[idx] + def id_to_name(self, motor_id: int) -> str: + return self._id_to_name[motor_id] @cached_property - def motor_names(self) -> list[str]: + def names(self) -> list[str]: return list(self.motors) @cached_property - def motor_models(self) -> list[str]: - return [model for _, model in self.motors.values()] + def models(self) -> list[str]: + return [m.model for m in self.motors.values()] @cached_property - def motor_ids(self) -> list[int]: - return [idx for idx, _ in self.motors.values()] + def ids(self) -> list[int]: + return [m.id for m in self.motors.values()] def _validate_motors(self) -> None: # TODO(aliberts): improve error messages for this (display problematics values) - if len(self.motor_ids) != len(set(self.motor_ids)): + if len(self.ids) != len(set(self.ids)): raise ValueError("Some motors have the same id.") - if len(self.motor_names) != len(set(self.motor_names)): + if len(self.names) != len(set(self.names)): raise ValueError("Some motors have the same name.") - if any(model not in self.model_resolution_table for model in self.motor_models): + if any(model not in self.model_resolution_table for model in self.models): raise ValueError("Some motors models are not available.") @property @@ -347,13 +356,13 @@ class MotorsBus(abc.ABC): """ try: # TODO(aliberts): use ping instead - return (self.motor_ids == self.read("ID")).all() + return (self.ids == self.read("ID")).all() except ConnectionError as e: print(e) return False - def ping(self, motor: str | int, num_retry: int | None = None) -> int: - idx = self.get_safe_id(motor) + def ping(self, motor: MotorLike, num_retry: int | None = None) -> int: + idx = self.get_motor_id(motor) for _ in range(num_retry): model_number, comm, _ = self.packet_handler.ping(self.port, idx) if self._is_comm_success(comm): @@ -453,16 +462,18 @@ class MotorsBus(abc.ABC): """ pass - def get_safe_id(self, motor: str | int) -> int: + def get_motor_id(self, motor: MotorLike) -> int: if isinstance(motor, str): - return self.motors[motor][0] + return self.motors[motor].id elif isinstance(motor, int): return motor + elif isinstance(motor, Motor): + return motor.id else: - raise ValueError(f"{motor} should be int or str.") + raise ValueError(f"{motor} should be int, str or Motor.") def read( - self, data_name: str, motors: str | int | list[str | int] | None = None, num_retry: int = 1 + self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1 ) -> dict[str, float]: if not self.is_connected: raise DeviceNotConnectedError( @@ -472,17 +483,17 @@ class MotorsBus(abc.ABC): start_time = time.perf_counter() if motors is None: - motors = self.motor_ids + motors = self.ids if isinstance(motors, (str, int)): motors = [motors] - motor_ids = [self.get_safe_id(motor) for motor in motors] + motor_ids = [self.get_motor_id(motor) for motor in motors] if self._has_different_ctrl_tables: - models = [self.idx_to_model(idx) for idx in motor_ids] + models = [self.id_to_model(idx) for idx in motor_ids] assert_same_address(self.model_ctrl_table, models, data_name) - model = self.idx_to_model(next(iter(motor_ids))) + model = self.id_to_model(next(iter(motor_ids))) addr, n_bytes = self.model_ctrl_table[model][data_name] comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry) @@ -496,7 +507,7 @@ class MotorsBus(abc.ABC): ids_values = self.calibrate_values(ids_values) # TODO(aliberts): return keys in the same format we got them? - ids_values = {self.idx_to_name(idx): val for idx, val in ids_values.items()} + ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()} # log the number of seconds it took to read the data from the motors delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids) @@ -540,7 +551,7 @@ class MotorsBus(abc.ABC): # for idx in motor_ids: # value = self.reader.getData(idx, address, n_bytes) - def write(self, data_name: str, values_dict: dict[str | int, int], num_retry: int = 1) -> None: + def write(self, data_name: str, values_dict: dict[MotorLike, int], num_retry: int = 1) -> None: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." @@ -548,16 +559,16 @@ class MotorsBus(abc.ABC): start_time = time.perf_counter() - ids_values = {self.get_safe_id(motor): val for motor, val in values_dict.items()} + ids_values = {self.get_motor_id(motor): val for motor, val in values_dict.items()} if self._has_different_ctrl_tables: - models = [self.idx_to_model(idx) for idx in ids_values] + models = [self.id_to_model(idx) for idx in ids_values] assert_same_address(self.model_ctrl_table, models, data_name) if data_name in self.calibration_required and self.calibration is not None: ids_values = self.uncalibrate_values(ids_values) - model = self.idx_to_model(next(iter(ids_values))) + model = self.id_to_model(next(iter(ids_values))) addr, n_bytes = self.model_ctrl_table[model][data_name] comm = self._write(ids_values, addr, n_bytes, num_retry) diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index ec391dc4..ce50dd91 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -4,6 +4,7 @@ from unittest.mock import patch import dynamixel_sdk as dxl import pytest +from lerobot.common.motors import Motor from lerobot.common.motors.dynamixel import DynamixelMotorsBus from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler @@ -17,6 +18,15 @@ def patch_port_handler(): yield +@pytest.fixture +def dummy_motors() -> dict[str, Motor]: + return { + "dummy_1": Motor(id=1, model="xl430-w250"), + "dummy_2": Motor(id=2, model="xm540-w270"), + "dummy_3": Motor(id=3, model="xl330-m077"), + } + + @pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}") def test_autouse_patch(): """Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler.""" @@ -68,8 +78,9 @@ def test_split_int_bytes_large_number(): DynamixelMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF -def test_abc_implementation(): +def test_abc_implementation(dummy_motors): """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" + DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors) DynamixelMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "xl330-m077")}) @@ -83,17 +94,13 @@ def test_abc_implementation(): ], ids=["None", "by ids", "by names", "mixed"], ) -def test_read_all_motors(motors): +def test_read_all_motors(motors, dummy_motors): mock_motors = MockMotors([1, 2, 3]) positions = [1337, 42, 4016] - mock_motors.build_all_motors_stub("Present_Position", return_values=positions) + mock_motors.build_sync_read_all_motors_stub("Present_Position", return_values=positions) motors_bus = DynamixelMotorsBus( port=mock_motors.port, - motors={ - "dummy_1": (1, "xl330-m077"), - "dummy_2": (2, "xl330-m077"), - "dummy_3": (3, "xl330-m077"), - }, + motors=dummy_motors, ) motors_bus.connect() @@ -113,16 +120,12 @@ def test_read_all_motors(motors): [3, 4016], ], ) -def test_read_single_motor_by_name(idx, pos): +def test_read_single_motor_by_name(idx, pos, dummy_motors): mock_motors = MockMotors([1, 2, 3]) - mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) + mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos) motors_bus = DynamixelMotorsBus( port=mock_motors.port, - motors={ - "dummy_1": (1, "xl330-m077"), - "dummy_2": (2, "xl330-m077"), - "dummy_3": (3, "xl330-m077"), - }, + motors=dummy_motors, ) motors_bus.connect() @@ -141,16 +144,12 @@ def test_read_single_motor_by_name(idx, pos): [3, 4016], ], ) -def test_read_single_motor_by_id(idx, pos): +def test_read_single_motor_by_id(idx, pos, dummy_motors): mock_motors = MockMotors([1, 2, 3]) - mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) + mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos) motors_bus = DynamixelMotorsBus( port=mock_motors.port, - motors={ - "dummy_1": (1, "xl330-m077"), - "dummy_2": (2, "xl330-m077"), - "dummy_3": (3, "xl330-m077"), - }, + motors=dummy_motors, ) motors_bus.connect() @@ -170,18 +169,14 @@ def test_read_single_motor_by_id(idx, pos): [2, 1, 999], ], ) -def test_read_num_retry(num_retry, num_invalid_try, pos): +def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors): mock_motors = MockMotors([1, 2, 3]) - mock_motors.build_single_motor_stubs( + mock_motors.build_sync_read_single_motor_stubs( "Present_Position", return_value=pos, num_invalid_try=num_invalid_try ) motors_bus = DynamixelMotorsBus( port=mock_motors.port, - motors={ - "dummy_1": (1, "xl330-m077"), - "dummy_2": (2, "xl330-m077"), - "dummy_3": (3, "xl330-m077"), - }, + motors=dummy_motors, ) motors_bus.connect()