Add Motor class
This commit is contained in:
parent
56c04ffc53
commit
a32081757d
|
@ -1,3 +1 @@
|
|||
from .motors_bus import CalibrationMode, DriveMode, MotorsBus, TorqueMode
|
||||
|
||||
__all__ = ["CalibrationMode", "DriveMode", "MotorsBus", "TorqueMode"]
|
||||
from .motors_bus import CalibrationMode, DriveMode, Motor, MotorsBus, TorqueMode
|
||||
|
|
|
@ -22,11 +22,12 @@
|
|||
import abc
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Protocol
|
||||
from typing import Protocol, TypeAlias, Union
|
||||
|
||||
import serial
|
||||
import tqdm
|
||||
|
@ -35,6 +36,8 @@ from deepdiff import DeepDiff
|
|||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
MotorLike: TypeAlias = Union[str, int, "Motor"]
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
|
||||
|
@ -197,6 +200,12 @@ class GroupSyncWrite(Protocol):
|
|||
def txPacket(self): ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class Motor:
|
||||
id: int
|
||||
model: str
|
||||
|
||||
|
||||
class MotorsBus(abc.ABC):
|
||||
"""The main LeRobot class for implementing motors buses.
|
||||
|
||||
|
@ -248,7 +257,7 @@ class MotorsBus(abc.ABC):
|
|||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
motors: dict[str, Motor],
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
|
@ -262,8 +271,8 @@ class MotorsBus(abc.ABC):
|
|||
self.logs = {} # TODO(aliberts): use subclass logger
|
||||
self.calibration = None
|
||||
|
||||
self._id_to_model = dict(self.motors.values())
|
||||
self._id_to_name = {idx: name for name, (idx, _) in self.motors.items()}
|
||||
self._id_to_model = {m.id: m.model for m in self.motors.values()}
|
||||
self._id_to_name = {m.id: name for name, m in self.motors.items()}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.motors)
|
||||
|
@ -278,39 +287,39 @@ class MotorsBus(abc.ABC):
|
|||
|
||||
@cached_property
|
||||
def _has_different_ctrl_tables(self) -> bool:
|
||||
if len(self.motor_models) < 2:
|
||||
if len(self.models) < 2:
|
||||
return False
|
||||
|
||||
first_table = self.model_ctrl_table[self.motor_models[0]]
|
||||
return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.motor_models[1:])
|
||||
first_table = self.model_ctrl_table[self.models[0]]
|
||||
return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.models[1:])
|
||||
|
||||
def idx_to_model(self, idx: int) -> str:
|
||||
return self._id_to_model[idx]
|
||||
def id_to_model(self, motor_id: int) -> str:
|
||||
return self._id_to_model[motor_id]
|
||||
|
||||
def idx_to_name(self, idx: int) -> str:
|
||||
return self._id_to_name[idx]
|
||||
def id_to_name(self, motor_id: int) -> str:
|
||||
return self._id_to_name[motor_id]
|
||||
|
||||
@cached_property
|
||||
def motor_names(self) -> list[str]:
|
||||
def names(self) -> list[str]:
|
||||
return list(self.motors)
|
||||
|
||||
@cached_property
|
||||
def motor_models(self) -> list[str]:
|
||||
return [model for _, model in self.motors.values()]
|
||||
def models(self) -> list[str]:
|
||||
return [m.model for m in self.motors.values()]
|
||||
|
||||
@cached_property
|
||||
def motor_ids(self) -> list[int]:
|
||||
return [idx for idx, _ in self.motors.values()]
|
||||
def ids(self) -> list[int]:
|
||||
return [m.id for m in self.motors.values()]
|
||||
|
||||
def _validate_motors(self) -> None:
|
||||
# TODO(aliberts): improve error messages for this (display problematics values)
|
||||
if len(self.motor_ids) != len(set(self.motor_ids)):
|
||||
if len(self.ids) != len(set(self.ids)):
|
||||
raise ValueError("Some motors have the same id.")
|
||||
|
||||
if len(self.motor_names) != len(set(self.motor_names)):
|
||||
if len(self.names) != len(set(self.names)):
|
||||
raise ValueError("Some motors have the same name.")
|
||||
|
||||
if any(model not in self.model_resolution_table for model in self.motor_models):
|
||||
if any(model not in self.model_resolution_table for model in self.models):
|
||||
raise ValueError("Some motors models are not available.")
|
||||
|
||||
@property
|
||||
|
@ -347,13 +356,13 @@ class MotorsBus(abc.ABC):
|
|||
"""
|
||||
try:
|
||||
# TODO(aliberts): use ping instead
|
||||
return (self.motor_ids == self.read("ID")).all()
|
||||
return (self.ids == self.read("ID")).all()
|
||||
except ConnectionError as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
def ping(self, motor: str | int, num_retry: int | None = None) -> int:
|
||||
idx = self.get_safe_id(motor)
|
||||
def ping(self, motor: MotorLike, num_retry: int | None = None) -> int:
|
||||
idx = self.get_motor_id(motor)
|
||||
for _ in range(num_retry):
|
||||
model_number, comm, _ = self.packet_handler.ping(self.port, idx)
|
||||
if self._is_comm_success(comm):
|
||||
|
@ -453,16 +462,18 @@ class MotorsBus(abc.ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def get_safe_id(self, motor: str | int) -> int:
|
||||
def get_motor_id(self, motor: MotorLike) -> int:
|
||||
if isinstance(motor, str):
|
||||
return self.motors[motor][0]
|
||||
return self.motors[motor].id
|
||||
elif isinstance(motor, int):
|
||||
return motor
|
||||
elif isinstance(motor, Motor):
|
||||
return motor.id
|
||||
else:
|
||||
raise ValueError(f"{motor} should be int or str.")
|
||||
raise ValueError(f"{motor} should be int, str or Motor.")
|
||||
|
||||
def read(
|
||||
self, data_name: str, motors: str | int | list[str | int] | None = None, num_retry: int = 1
|
||||
self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1
|
||||
) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
|
@ -472,17 +483,17 @@ class MotorsBus(abc.ABC):
|
|||
start_time = time.perf_counter()
|
||||
|
||||
if motors is None:
|
||||
motors = self.motor_ids
|
||||
motors = self.ids
|
||||
|
||||
if isinstance(motors, (str, int)):
|
||||
motors = [motors]
|
||||
|
||||
motor_ids = [self.get_safe_id(motor) for motor in motors]
|
||||
motor_ids = [self.get_motor_id(motor) for motor in motors]
|
||||
if self._has_different_ctrl_tables:
|
||||
models = [self.idx_to_model(idx) for idx in motor_ids]
|
||||
models = [self.id_to_model(idx) for idx in motor_ids]
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = self.idx_to_model(next(iter(motor_ids)))
|
||||
model = self.id_to_model(next(iter(motor_ids)))
|
||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||
|
||||
comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry)
|
||||
|
@ -496,7 +507,7 @@ class MotorsBus(abc.ABC):
|
|||
ids_values = self.calibrate_values(ids_values)
|
||||
|
||||
# TODO(aliberts): return keys in the same format we got them?
|
||||
ids_values = {self.idx_to_name(idx): val for idx, val in ids_values.items()}
|
||||
ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()}
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
|
||||
|
@ -540,7 +551,7 @@ class MotorsBus(abc.ABC):
|
|||
# for idx in motor_ids:
|
||||
# value = self.reader.getData(idx, address, n_bytes)
|
||||
|
||||
def write(self, data_name: str, values_dict: dict[str | int, int], num_retry: int = 1) -> None:
|
||||
def write(self, data_name: str, values_dict: dict[MotorLike, int], num_retry: int = 1) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
|
@ -548,16 +559,16 @@ class MotorsBus(abc.ABC):
|
|||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
ids_values = {self.get_safe_id(motor): val for motor, val in values_dict.items()}
|
||||
ids_values = {self.get_motor_id(motor): val for motor, val in values_dict.items()}
|
||||
|
||||
if self._has_different_ctrl_tables:
|
||||
models = [self.idx_to_model(idx) for idx in ids_values]
|
||||
models = [self.id_to_model(idx) for idx in ids_values]
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
if data_name in self.calibration_required and self.calibration is not None:
|
||||
ids_values = self.uncalibrate_values(ids_values)
|
||||
|
||||
model = self.idx_to_model(next(iter(ids_values)))
|
||||
model = self.id_to_model(next(iter(ids_values)))
|
||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||
|
||||
comm = self._write(ids_values, addr, n_bytes, num_retry)
|
||||
|
|
|
@ -4,6 +4,7 @@ from unittest.mock import patch
|
|||
import dynamixel_sdk as dxl
|
||||
import pytest
|
||||
|
||||
from lerobot.common.motors import Motor
|
||||
from lerobot.common.motors.dynamixel import DynamixelMotorsBus
|
||||
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
|
||||
|
||||
|
@ -17,6 +18,15 @@ def patch_port_handler():
|
|||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_motors() -> dict[str, Motor]:
|
||||
return {
|
||||
"dummy_1": Motor(id=1, model="xl430-w250"),
|
||||
"dummy_2": Motor(id=2, model="xm540-w270"),
|
||||
"dummy_3": Motor(id=3, model="xl330-m077"),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
|
||||
def test_autouse_patch():
|
||||
"""Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler."""
|
||||
|
@ -68,8 +78,9 @@ def test_split_int_bytes_large_number():
|
|||
DynamixelMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
||||
|
||||
|
||||
def test_abc_implementation():
|
||||
def test_abc_implementation(dummy_motors):
|
||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
||||
DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
|
||||
DynamixelMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "xl330-m077")})
|
||||
|
||||
|
||||
|
@ -83,17 +94,13 @@ def test_abc_implementation():
|
|||
],
|
||||
ids=["None", "by ids", "by names", "mixed"],
|
||||
)
|
||||
def test_read_all_motors(motors):
|
||||
def test_read_all_motors(motors, dummy_motors):
|
||||
mock_motors = MockMotors([1, 2, 3])
|
||||
positions = [1337, 42, 4016]
|
||||
mock_motors.build_all_motors_stub("Present_Position", return_values=positions)
|
||||
mock_motors.build_sync_read_all_motors_stub("Present_Position", return_values=positions)
|
||||
motors_bus = DynamixelMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors={
|
||||
"dummy_1": (1, "xl330-m077"),
|
||||
"dummy_2": (2, "xl330-m077"),
|
||||
"dummy_3": (3, "xl330-m077"),
|
||||
},
|
||||
motors=dummy_motors,
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
|
@ -113,16 +120,12 @@ def test_read_all_motors(motors):
|
|||
[3, 4016],
|
||||
],
|
||||
)
|
||||
def test_read_single_motor_by_name(idx, pos):
|
||||
def test_read_single_motor_by_name(idx, pos, dummy_motors):
|
||||
mock_motors = MockMotors([1, 2, 3])
|
||||
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
|
||||
mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos)
|
||||
motors_bus = DynamixelMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors={
|
||||
"dummy_1": (1, "xl330-m077"),
|
||||
"dummy_2": (2, "xl330-m077"),
|
||||
"dummy_3": (3, "xl330-m077"),
|
||||
},
|
||||
motors=dummy_motors,
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
|
@ -141,16 +144,12 @@ def test_read_single_motor_by_name(idx, pos):
|
|||
[3, 4016],
|
||||
],
|
||||
)
|
||||
def test_read_single_motor_by_id(idx, pos):
|
||||
def test_read_single_motor_by_id(idx, pos, dummy_motors):
|
||||
mock_motors = MockMotors([1, 2, 3])
|
||||
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
|
||||
mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos)
|
||||
motors_bus = DynamixelMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors={
|
||||
"dummy_1": (1, "xl330-m077"),
|
||||
"dummy_2": (2, "xl330-m077"),
|
||||
"dummy_3": (3, "xl330-m077"),
|
||||
},
|
||||
motors=dummy_motors,
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
|
@ -170,18 +169,14 @@ def test_read_single_motor_by_id(idx, pos):
|
|||
[2, 1, 999],
|
||||
],
|
||||
)
|
||||
def test_read_num_retry(num_retry, num_invalid_try, pos):
|
||||
def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
|
||||
mock_motors = MockMotors([1, 2, 3])
|
||||
mock_motors.build_single_motor_stubs(
|
||||
mock_motors.build_sync_read_single_motor_stubs(
|
||||
"Present_Position", return_value=pos, num_invalid_try=num_invalid_try
|
||||
)
|
||||
motors_bus = DynamixelMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors={
|
||||
"dummy_1": (1, "xl330-m077"),
|
||||
"dummy_2": (2, "xl330-m077"),
|
||||
"dummy_3": (3, "xl330-m077"),
|
||||
},
|
||||
motors=dummy_motors,
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
|
|
Loading…
Reference in New Issue