Add Motor class

This commit is contained in:
Simon Alibert 2025-03-21 12:13:44 +01:00
parent 56c04ffc53
commit a32081757d
3 changed files with 71 additions and 67 deletions

View File

@ -1,3 +1 @@
from .motors_bus import CalibrationMode, DriveMode, MotorsBus, TorqueMode
__all__ = ["CalibrationMode", "DriveMode", "MotorsBus", "TorqueMode"]
from .motors_bus import CalibrationMode, DriveMode, Motor, MotorsBus, TorqueMode

View File

@ -22,11 +22,12 @@
import abc
import json
import time
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from pathlib import Path
from pprint import pformat
from typing import Protocol
from typing import Protocol, TypeAlias, Union
import serial
import tqdm
@ -35,6 +36,8 @@ from deepdiff import DeepDiff
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
MotorLike: TypeAlias = Union[str, int, "Motor"]
MAX_ID_RANGE = 252
@ -197,6 +200,12 @@ class GroupSyncWrite(Protocol):
def txPacket(self): ...
@dataclass
class Motor:
id: int
model: str
class MotorsBus(abc.ABC):
"""The main LeRobot class for implementing motors buses.
@ -248,7 +257,7 @@ class MotorsBus(abc.ABC):
def __init__(
self,
port: str,
motors: dict[str, tuple[int, str]],
motors: dict[str, Motor],
):
self.port = port
self.motors = motors
@ -262,8 +271,8 @@ class MotorsBus(abc.ABC):
self.logs = {} # TODO(aliberts): use subclass logger
self.calibration = None
self._id_to_model = dict(self.motors.values())
self._id_to_name = {idx: name for name, (idx, _) in self.motors.items()}
self._id_to_model = {m.id: m.model for m in self.motors.values()}
self._id_to_name = {m.id: name for name, m in self.motors.items()}
def __len__(self):
return len(self.motors)
@ -278,39 +287,39 @@ class MotorsBus(abc.ABC):
@cached_property
def _has_different_ctrl_tables(self) -> bool:
if len(self.motor_models) < 2:
if len(self.models) < 2:
return False
first_table = self.model_ctrl_table[self.motor_models[0]]
return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.motor_models[1:])
first_table = self.model_ctrl_table[self.models[0]]
return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.models[1:])
def idx_to_model(self, idx: int) -> str:
return self._id_to_model[idx]
def id_to_model(self, motor_id: int) -> str:
return self._id_to_model[motor_id]
def idx_to_name(self, idx: int) -> str:
return self._id_to_name[idx]
def id_to_name(self, motor_id: int) -> str:
return self._id_to_name[motor_id]
@cached_property
def motor_names(self) -> list[str]:
def names(self) -> list[str]:
return list(self.motors)
@cached_property
def motor_models(self) -> list[str]:
return [model for _, model in self.motors.values()]
def models(self) -> list[str]:
return [m.model for m in self.motors.values()]
@cached_property
def motor_ids(self) -> list[int]:
return [idx for idx, _ in self.motors.values()]
def ids(self) -> list[int]:
return [m.id for m in self.motors.values()]
def _validate_motors(self) -> None:
# TODO(aliberts): improve error messages for this (display problematics values)
if len(self.motor_ids) != len(set(self.motor_ids)):
if len(self.ids) != len(set(self.ids)):
raise ValueError("Some motors have the same id.")
if len(self.motor_names) != len(set(self.motor_names)):
if len(self.names) != len(set(self.names)):
raise ValueError("Some motors have the same name.")
if any(model not in self.model_resolution_table for model in self.motor_models):
if any(model not in self.model_resolution_table for model in self.models):
raise ValueError("Some motors models are not available.")
@property
@ -347,13 +356,13 @@ class MotorsBus(abc.ABC):
"""
try:
# TODO(aliberts): use ping instead
return (self.motor_ids == self.read("ID")).all()
return (self.ids == self.read("ID")).all()
except ConnectionError as e:
print(e)
return False
def ping(self, motor: str | int, num_retry: int | None = None) -> int:
idx = self.get_safe_id(motor)
def ping(self, motor: MotorLike, num_retry: int | None = None) -> int:
idx = self.get_motor_id(motor)
for _ in range(num_retry):
model_number, comm, _ = self.packet_handler.ping(self.port, idx)
if self._is_comm_success(comm):
@ -453,16 +462,18 @@ class MotorsBus(abc.ABC):
"""
pass
def get_safe_id(self, motor: str | int) -> int:
def get_motor_id(self, motor: MotorLike) -> int:
if isinstance(motor, str):
return self.motors[motor][0]
return self.motors[motor].id
elif isinstance(motor, int):
return motor
elif isinstance(motor, Motor):
return motor.id
else:
raise ValueError(f"{motor} should be int or str.")
raise ValueError(f"{motor} should be int, str or Motor.")
def read(
self, data_name: str, motors: str | int | list[str | int] | None = None, num_retry: int = 1
self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1
) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(
@ -472,17 +483,17 @@ class MotorsBus(abc.ABC):
start_time = time.perf_counter()
if motors is None:
motors = self.motor_ids
motors = self.ids
if isinstance(motors, (str, int)):
motors = [motors]
motor_ids = [self.get_safe_id(motor) for motor in motors]
motor_ids = [self.get_motor_id(motor) for motor in motors]
if self._has_different_ctrl_tables:
models = [self.idx_to_model(idx) for idx in motor_ids]
models = [self.id_to_model(idx) for idx in motor_ids]
assert_same_address(self.model_ctrl_table, models, data_name)
model = self.idx_to_model(next(iter(motor_ids)))
model = self.id_to_model(next(iter(motor_ids)))
addr, n_bytes = self.model_ctrl_table[model][data_name]
comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry)
@ -496,7 +507,7 @@ class MotorsBus(abc.ABC):
ids_values = self.calibrate_values(ids_values)
# TODO(aliberts): return keys in the same format we got them?
ids_values = {self.idx_to_name(idx): val for idx, val in ids_values.items()}
ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()}
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
@ -540,7 +551,7 @@ class MotorsBus(abc.ABC):
# for idx in motor_ids:
# value = self.reader.getData(idx, address, n_bytes)
def write(self, data_name: str, values_dict: dict[str | int, int], num_retry: int = 1) -> None:
def write(self, data_name: str, values_dict: dict[MotorLike, int], num_retry: int = 1) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
@ -548,16 +559,16 @@ class MotorsBus(abc.ABC):
start_time = time.perf_counter()
ids_values = {self.get_safe_id(motor): val for motor, val in values_dict.items()}
ids_values = {self.get_motor_id(motor): val for motor, val in values_dict.items()}
if self._has_different_ctrl_tables:
models = [self.idx_to_model(idx) for idx in ids_values]
models = [self.id_to_model(idx) for idx in ids_values]
assert_same_address(self.model_ctrl_table, models, data_name)
if data_name in self.calibration_required and self.calibration is not None:
ids_values = self.uncalibrate_values(ids_values)
model = self.idx_to_model(next(iter(ids_values)))
model = self.id_to_model(next(iter(ids_values)))
addr, n_bytes = self.model_ctrl_table[model][data_name]
comm = self._write(ids_values, addr, n_bytes, num_retry)

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import dynamixel_sdk as dxl
import pytest
from lerobot.common.motors import Motor
from lerobot.common.motors.dynamixel import DynamixelMotorsBus
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
@ -17,6 +18,15 @@ def patch_port_handler():
yield
@pytest.fixture
def dummy_motors() -> dict[str, Motor]:
return {
"dummy_1": Motor(id=1, model="xl430-w250"),
"dummy_2": Motor(id=2, model="xm540-w270"),
"dummy_3": Motor(id=3, model="xl330-m077"),
}
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
def test_autouse_patch():
"""Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler."""
@ -68,8 +78,9 @@ def test_split_int_bytes_large_number():
DynamixelMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
def test_abc_implementation():
def test_abc_implementation(dummy_motors):
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
DynamixelMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "xl330-m077")})
@ -83,17 +94,13 @@ def test_abc_implementation():
],
ids=["None", "by ids", "by names", "mixed"],
)
def test_read_all_motors(motors):
def test_read_all_motors(motors, dummy_motors):
mock_motors = MockMotors([1, 2, 3])
positions = [1337, 42, 4016]
mock_motors.build_all_motors_stub("Present_Position", return_values=positions)
mock_motors.build_sync_read_all_motors_stub("Present_Position", return_values=positions)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors={
"dummy_1": (1, "xl330-m077"),
"dummy_2": (2, "xl330-m077"),
"dummy_3": (3, "xl330-m077"),
},
motors=dummy_motors,
)
motors_bus.connect()
@ -113,16 +120,12 @@ def test_read_all_motors(motors):
[3, 4016],
],
)
def test_read_single_motor_by_name(idx, pos):
def test_read_single_motor_by_name(idx, pos, dummy_motors):
mock_motors = MockMotors([1, 2, 3])
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors={
"dummy_1": (1, "xl330-m077"),
"dummy_2": (2, "xl330-m077"),
"dummy_3": (3, "xl330-m077"),
},
motors=dummy_motors,
)
motors_bus.connect()
@ -141,16 +144,12 @@ def test_read_single_motor_by_name(idx, pos):
[3, 4016],
],
)
def test_read_single_motor_by_id(idx, pos):
def test_read_single_motor_by_id(idx, pos, dummy_motors):
mock_motors = MockMotors([1, 2, 3])
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors={
"dummy_1": (1, "xl330-m077"),
"dummy_2": (2, "xl330-m077"),
"dummy_3": (3, "xl330-m077"),
},
motors=dummy_motors,
)
motors_bus.connect()
@ -170,18 +169,14 @@ def test_read_single_motor_by_id(idx, pos):
[2, 1, 999],
],
)
def test_read_num_retry(num_retry, num_invalid_try, pos):
def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
mock_motors = MockMotors([1, 2, 3])
mock_motors.build_single_motor_stubs(
mock_motors.build_sync_read_single_motor_stubs(
"Present_Position", return_value=pos, num_invalid_try=num_invalid_try
)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors={
"dummy_1": (1, "xl330-m077"),
"dummy_2": (2, "xl330-m077"),
"dummy_3": (3, "xl330-m077"),
},
motors=dummy_motors,
)
motors_bus.connect()