Add Motor class

This commit is contained in:
Simon Alibert 2025-03-21 12:13:44 +01:00
parent 56c04ffc53
commit a32081757d
3 changed files with 71 additions and 67 deletions

View File

@ -1,3 +1 @@
from .motors_bus import CalibrationMode, DriveMode, MotorsBus, TorqueMode from .motors_bus import CalibrationMode, DriveMode, Motor, MotorsBus, TorqueMode
__all__ = ["CalibrationMode", "DriveMode", "MotorsBus", "TorqueMode"]

View File

@ -22,11 +22,12 @@
import abc import abc
import json import json
import time import time
from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Protocol from typing import Protocol, TypeAlias, Union
import serial import serial
import tqdm import tqdm
@ -35,6 +36,8 @@ from deepdiff import DeepDiff
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc from lerobot.common.utils.utils import capture_timestamp_utc
MotorLike: TypeAlias = Union[str, int, "Motor"]
MAX_ID_RANGE = 252 MAX_ID_RANGE = 252
@ -197,6 +200,12 @@ class GroupSyncWrite(Protocol):
def txPacket(self): ... def txPacket(self): ...
@dataclass
class Motor:
id: int
model: str
class MotorsBus(abc.ABC): class MotorsBus(abc.ABC):
"""The main LeRobot class for implementing motors buses. """The main LeRobot class for implementing motors buses.
@ -248,7 +257,7 @@ class MotorsBus(abc.ABC):
def __init__( def __init__(
self, self,
port: str, port: str,
motors: dict[str, tuple[int, str]], motors: dict[str, Motor],
): ):
self.port = port self.port = port
self.motors = motors self.motors = motors
@ -262,8 +271,8 @@ class MotorsBus(abc.ABC):
self.logs = {} # TODO(aliberts): use subclass logger self.logs = {} # TODO(aliberts): use subclass logger
self.calibration = None self.calibration = None
self._id_to_model = dict(self.motors.values()) self._id_to_model = {m.id: m.model for m in self.motors.values()}
self._id_to_name = {idx: name for name, (idx, _) in self.motors.items()} self._id_to_name = {m.id: name for name, m in self.motors.items()}
def __len__(self): def __len__(self):
return len(self.motors) return len(self.motors)
@ -278,39 +287,39 @@ class MotorsBus(abc.ABC):
@cached_property @cached_property
def _has_different_ctrl_tables(self) -> bool: def _has_different_ctrl_tables(self) -> bool:
if len(self.motor_models) < 2: if len(self.models) < 2:
return False return False
first_table = self.model_ctrl_table[self.motor_models[0]] first_table = self.model_ctrl_table[self.models[0]]
return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.motor_models[1:]) return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.models[1:])
def idx_to_model(self, idx: int) -> str: def id_to_model(self, motor_id: int) -> str:
return self._id_to_model[idx] return self._id_to_model[motor_id]
def idx_to_name(self, idx: int) -> str: def id_to_name(self, motor_id: int) -> str:
return self._id_to_name[idx] return self._id_to_name[motor_id]
@cached_property @cached_property
def motor_names(self) -> list[str]: def names(self) -> list[str]:
return list(self.motors) return list(self.motors)
@cached_property @cached_property
def motor_models(self) -> list[str]: def models(self) -> list[str]:
return [model for _, model in self.motors.values()] return [m.model for m in self.motors.values()]
@cached_property @cached_property
def motor_ids(self) -> list[int]: def ids(self) -> list[int]:
return [idx for idx, _ in self.motors.values()] return [m.id for m in self.motors.values()]
def _validate_motors(self) -> None: def _validate_motors(self) -> None:
# TODO(aliberts): improve error messages for this (display problematics values) # TODO(aliberts): improve error messages for this (display problematics values)
if len(self.motor_ids) != len(set(self.motor_ids)): if len(self.ids) != len(set(self.ids)):
raise ValueError("Some motors have the same id.") raise ValueError("Some motors have the same id.")
if len(self.motor_names) != len(set(self.motor_names)): if len(self.names) != len(set(self.names)):
raise ValueError("Some motors have the same name.") raise ValueError("Some motors have the same name.")
if any(model not in self.model_resolution_table for model in self.motor_models): if any(model not in self.model_resolution_table for model in self.models):
raise ValueError("Some motors models are not available.") raise ValueError("Some motors models are not available.")
@property @property
@ -347,13 +356,13 @@ class MotorsBus(abc.ABC):
""" """
try: try:
# TODO(aliberts): use ping instead # TODO(aliberts): use ping instead
return (self.motor_ids == self.read("ID")).all() return (self.ids == self.read("ID")).all()
except ConnectionError as e: except ConnectionError as e:
print(e) print(e)
return False return False
def ping(self, motor: str | int, num_retry: int | None = None) -> int: def ping(self, motor: MotorLike, num_retry: int | None = None) -> int:
idx = self.get_safe_id(motor) idx = self.get_motor_id(motor)
for _ in range(num_retry): for _ in range(num_retry):
model_number, comm, _ = self.packet_handler.ping(self.port, idx) model_number, comm, _ = self.packet_handler.ping(self.port, idx)
if self._is_comm_success(comm): if self._is_comm_success(comm):
@ -453,16 +462,18 @@ class MotorsBus(abc.ABC):
""" """
pass pass
def get_safe_id(self, motor: str | int) -> int: def get_motor_id(self, motor: MotorLike) -> int:
if isinstance(motor, str): if isinstance(motor, str):
return self.motors[motor][0] return self.motors[motor].id
elif isinstance(motor, int): elif isinstance(motor, int):
return motor return motor
elif isinstance(motor, Motor):
return motor.id
else: else:
raise ValueError(f"{motor} should be int or str.") raise ValueError(f"{motor} should be int, str or Motor.")
def read( def read(
self, data_name: str, motors: str | int | list[str | int] | None = None, num_retry: int = 1 self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1
) -> dict[str, float]: ) -> dict[str, float]:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
@ -472,17 +483,17 @@ class MotorsBus(abc.ABC):
start_time = time.perf_counter() start_time = time.perf_counter()
if motors is None: if motors is None:
motors = self.motor_ids motors = self.ids
if isinstance(motors, (str, int)): if isinstance(motors, (str, int)):
motors = [motors] motors = [motors]
motor_ids = [self.get_safe_id(motor) for motor in motors] motor_ids = [self.get_motor_id(motor) for motor in motors]
if self._has_different_ctrl_tables: if self._has_different_ctrl_tables:
models = [self.idx_to_model(idx) for idx in motor_ids] models = [self.id_to_model(idx) for idx in motor_ids]
assert_same_address(self.model_ctrl_table, models, data_name) assert_same_address(self.model_ctrl_table, models, data_name)
model = self.idx_to_model(next(iter(motor_ids))) model = self.id_to_model(next(iter(motor_ids)))
addr, n_bytes = self.model_ctrl_table[model][data_name] addr, n_bytes = self.model_ctrl_table[model][data_name]
comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry) comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry)
@ -496,7 +507,7 @@ class MotorsBus(abc.ABC):
ids_values = self.calibrate_values(ids_values) ids_values = self.calibrate_values(ids_values)
# TODO(aliberts): return keys in the same format we got them? # TODO(aliberts): return keys in the same format we got them?
ids_values = {self.idx_to_name(idx): val for idx, val in ids_values.items()} ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()}
# log the number of seconds it took to read the data from the motors # log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids) delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
@ -540,7 +551,7 @@ class MotorsBus(abc.ABC):
# for idx in motor_ids: # for idx in motor_ids:
# value = self.reader.getData(idx, address, n_bytes) # value = self.reader.getData(idx, address, n_bytes)
def write(self, data_name: str, values_dict: dict[str | int, int], num_retry: int = 1) -> None: def write(self, data_name: str, values_dict: dict[MotorLike, int], num_retry: int = 1) -> None:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
@ -548,16 +559,16 @@ class MotorsBus(abc.ABC):
start_time = time.perf_counter() start_time = time.perf_counter()
ids_values = {self.get_safe_id(motor): val for motor, val in values_dict.items()} ids_values = {self.get_motor_id(motor): val for motor, val in values_dict.items()}
if self._has_different_ctrl_tables: if self._has_different_ctrl_tables:
models = [self.idx_to_model(idx) for idx in ids_values] models = [self.id_to_model(idx) for idx in ids_values]
assert_same_address(self.model_ctrl_table, models, data_name) assert_same_address(self.model_ctrl_table, models, data_name)
if data_name in self.calibration_required and self.calibration is not None: if data_name in self.calibration_required and self.calibration is not None:
ids_values = self.uncalibrate_values(ids_values) ids_values = self.uncalibrate_values(ids_values)
model = self.idx_to_model(next(iter(ids_values))) model = self.id_to_model(next(iter(ids_values)))
addr, n_bytes = self.model_ctrl_table[model][data_name] addr, n_bytes = self.model_ctrl_table[model][data_name]
comm = self._write(ids_values, addr, n_bytes, num_retry) comm = self._write(ids_values, addr, n_bytes, num_retry)

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import dynamixel_sdk as dxl import dynamixel_sdk as dxl
import pytest import pytest
from lerobot.common.motors import Motor
from lerobot.common.motors.dynamixel import DynamixelMotorsBus from lerobot.common.motors.dynamixel import DynamixelMotorsBus
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
@ -17,6 +18,15 @@ def patch_port_handler():
yield yield
@pytest.fixture
def dummy_motors() -> dict[str, Motor]:
return {
"dummy_1": Motor(id=1, model="xl430-w250"),
"dummy_2": Motor(id=2, model="xm540-w270"),
"dummy_3": Motor(id=3, model="xl330-m077"),
}
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}") @pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
def test_autouse_patch(): def test_autouse_patch():
"""Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler.""" """Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler."""
@ -68,8 +78,9 @@ def test_split_int_bytes_large_number():
DynamixelMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF DynamixelMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
def test_abc_implementation(): def test_abc_implementation(dummy_motors):
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" """Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
DynamixelMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "xl330-m077")}) DynamixelMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "xl330-m077")})
@ -83,17 +94,13 @@ def test_abc_implementation():
], ],
ids=["None", "by ids", "by names", "mixed"], ids=["None", "by ids", "by names", "mixed"],
) )
def test_read_all_motors(motors): def test_read_all_motors(motors, dummy_motors):
mock_motors = MockMotors([1, 2, 3]) mock_motors = MockMotors([1, 2, 3])
positions = [1337, 42, 4016] positions = [1337, 42, 4016]
mock_motors.build_all_motors_stub("Present_Position", return_values=positions) mock_motors.build_sync_read_all_motors_stub("Present_Position", return_values=positions)
motors_bus = DynamixelMotorsBus( motors_bus = DynamixelMotorsBus(
port=mock_motors.port, port=mock_motors.port,
motors={ motors=dummy_motors,
"dummy_1": (1, "xl330-m077"),
"dummy_2": (2, "xl330-m077"),
"dummy_3": (3, "xl330-m077"),
},
) )
motors_bus.connect() motors_bus.connect()
@ -113,16 +120,12 @@ def test_read_all_motors(motors):
[3, 4016], [3, 4016],
], ],
) )
def test_read_single_motor_by_name(idx, pos): def test_read_single_motor_by_name(idx, pos, dummy_motors):
mock_motors = MockMotors([1, 2, 3]) mock_motors = MockMotors([1, 2, 3])
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos)
motors_bus = DynamixelMotorsBus( motors_bus = DynamixelMotorsBus(
port=mock_motors.port, port=mock_motors.port,
motors={ motors=dummy_motors,
"dummy_1": (1, "xl330-m077"),
"dummy_2": (2, "xl330-m077"),
"dummy_3": (3, "xl330-m077"),
},
) )
motors_bus.connect() motors_bus.connect()
@ -141,16 +144,12 @@ def test_read_single_motor_by_name(idx, pos):
[3, 4016], [3, 4016],
], ],
) )
def test_read_single_motor_by_id(idx, pos): def test_read_single_motor_by_id(idx, pos, dummy_motors):
mock_motors = MockMotors([1, 2, 3]) mock_motors = MockMotors([1, 2, 3])
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos) mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos)
motors_bus = DynamixelMotorsBus( motors_bus = DynamixelMotorsBus(
port=mock_motors.port, port=mock_motors.port,
motors={ motors=dummy_motors,
"dummy_1": (1, "xl330-m077"),
"dummy_2": (2, "xl330-m077"),
"dummy_3": (3, "xl330-m077"),
},
) )
motors_bus.connect() motors_bus.connect()
@ -170,18 +169,14 @@ def test_read_single_motor_by_id(idx, pos):
[2, 1, 999], [2, 1, 999],
], ],
) )
def test_read_num_retry(num_retry, num_invalid_try, pos): def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
mock_motors = MockMotors([1, 2, 3]) mock_motors = MockMotors([1, 2, 3])
mock_motors.build_single_motor_stubs( mock_motors.build_sync_read_single_motor_stubs(
"Present_Position", return_value=pos, num_invalid_try=num_invalid_try "Present_Position", return_value=pos, num_invalid_try=num_invalid_try
) )
motors_bus = DynamixelMotorsBus( motors_bus = DynamixelMotorsBus(
port=mock_motors.port, port=mock_motors.port,
motors={ motors=dummy_motors,
"dummy_1": (1, "xl330-m077"),
"dummy_2": (2, "xl330-m077"),
"dummy_3": (3, "xl330-m077"),
},
) )
motors_bus.connect() motors_bus.connect()