Add Motor class
This commit is contained in:
parent
56c04ffc53
commit
a32081757d
|
@ -1,3 +1 @@
|
||||||
from .motors_bus import CalibrationMode, DriveMode, MotorsBus, TorqueMode
|
from .motors_bus import CalibrationMode, DriveMode, Motor, MotorsBus, TorqueMode
|
||||||
|
|
||||||
__all__ = ["CalibrationMode", "DriveMode", "MotorsBus", "TorqueMode"]
|
|
||||||
|
|
|
@ -22,11 +22,12 @@
|
||||||
import abc
|
import abc
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Protocol
|
from typing import Protocol, TypeAlias, Union
|
||||||
|
|
||||||
import serial
|
import serial
|
||||||
import tqdm
|
import tqdm
|
||||||
|
@ -35,6 +36,8 @@ from deepdiff import DeepDiff
|
||||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||||
|
|
||||||
|
MotorLike: TypeAlias = Union[str, int, "Motor"]
|
||||||
|
|
||||||
MAX_ID_RANGE = 252
|
MAX_ID_RANGE = 252
|
||||||
|
|
||||||
|
|
||||||
|
@ -197,6 +200,12 @@ class GroupSyncWrite(Protocol):
|
||||||
def txPacket(self): ...
|
def txPacket(self): ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Motor:
|
||||||
|
id: int
|
||||||
|
model: str
|
||||||
|
|
||||||
|
|
||||||
class MotorsBus(abc.ABC):
|
class MotorsBus(abc.ABC):
|
||||||
"""The main LeRobot class for implementing motors buses.
|
"""The main LeRobot class for implementing motors buses.
|
||||||
|
|
||||||
|
@ -248,7 +257,7 @@ class MotorsBus(abc.ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
port: str,
|
port: str,
|
||||||
motors: dict[str, tuple[int, str]],
|
motors: dict[str, Motor],
|
||||||
):
|
):
|
||||||
self.port = port
|
self.port = port
|
||||||
self.motors = motors
|
self.motors = motors
|
||||||
|
@ -262,8 +271,8 @@ class MotorsBus(abc.ABC):
|
||||||
self.logs = {} # TODO(aliberts): use subclass logger
|
self.logs = {} # TODO(aliberts): use subclass logger
|
||||||
self.calibration = None
|
self.calibration = None
|
||||||
|
|
||||||
self._id_to_model = dict(self.motors.values())
|
self._id_to_model = {m.id: m.model for m in self.motors.values()}
|
||||||
self._id_to_name = {idx: name for name, (idx, _) in self.motors.items()}
|
self._id_to_name = {m.id: name for name, m in self.motors.items()}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.motors)
|
return len(self.motors)
|
||||||
|
@ -278,39 +287,39 @@ class MotorsBus(abc.ABC):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def _has_different_ctrl_tables(self) -> bool:
|
def _has_different_ctrl_tables(self) -> bool:
|
||||||
if len(self.motor_models) < 2:
|
if len(self.models) < 2:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
first_table = self.model_ctrl_table[self.motor_models[0]]
|
first_table = self.model_ctrl_table[self.models[0]]
|
||||||
return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.motor_models[1:])
|
return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.models[1:])
|
||||||
|
|
||||||
def idx_to_model(self, idx: int) -> str:
|
def id_to_model(self, motor_id: int) -> str:
|
||||||
return self._id_to_model[idx]
|
return self._id_to_model[motor_id]
|
||||||
|
|
||||||
def idx_to_name(self, idx: int) -> str:
|
def id_to_name(self, motor_id: int) -> str:
|
||||||
return self._id_to_name[idx]
|
return self._id_to_name[motor_id]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def motor_names(self) -> list[str]:
|
def names(self) -> list[str]:
|
||||||
return list(self.motors)
|
return list(self.motors)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def motor_models(self) -> list[str]:
|
def models(self) -> list[str]:
|
||||||
return [model for _, model in self.motors.values()]
|
return [m.model for m in self.motors.values()]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def motor_ids(self) -> list[int]:
|
def ids(self) -> list[int]:
|
||||||
return [idx for idx, _ in self.motors.values()]
|
return [m.id for m in self.motors.values()]
|
||||||
|
|
||||||
def _validate_motors(self) -> None:
|
def _validate_motors(self) -> None:
|
||||||
# TODO(aliberts): improve error messages for this (display problematics values)
|
# TODO(aliberts): improve error messages for this (display problematics values)
|
||||||
if len(self.motor_ids) != len(set(self.motor_ids)):
|
if len(self.ids) != len(set(self.ids)):
|
||||||
raise ValueError("Some motors have the same id.")
|
raise ValueError("Some motors have the same id.")
|
||||||
|
|
||||||
if len(self.motor_names) != len(set(self.motor_names)):
|
if len(self.names) != len(set(self.names)):
|
||||||
raise ValueError("Some motors have the same name.")
|
raise ValueError("Some motors have the same name.")
|
||||||
|
|
||||||
if any(model not in self.model_resolution_table for model in self.motor_models):
|
if any(model not in self.model_resolution_table for model in self.models):
|
||||||
raise ValueError("Some motors models are not available.")
|
raise ValueError("Some motors models are not available.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -347,13 +356,13 @@ class MotorsBus(abc.ABC):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# TODO(aliberts): use ping instead
|
# TODO(aliberts): use ping instead
|
||||||
return (self.motor_ids == self.read("ID")).all()
|
return (self.ids == self.read("ID")).all()
|
||||||
except ConnectionError as e:
|
except ConnectionError as e:
|
||||||
print(e)
|
print(e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def ping(self, motor: str | int, num_retry: int | None = None) -> int:
|
def ping(self, motor: MotorLike, num_retry: int | None = None) -> int:
|
||||||
idx = self.get_safe_id(motor)
|
idx = self.get_motor_id(motor)
|
||||||
for _ in range(num_retry):
|
for _ in range(num_retry):
|
||||||
model_number, comm, _ = self.packet_handler.ping(self.port, idx)
|
model_number, comm, _ = self.packet_handler.ping(self.port, idx)
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
|
@ -453,16 +462,18 @@ class MotorsBus(abc.ABC):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_safe_id(self, motor: str | int) -> int:
|
def get_motor_id(self, motor: MotorLike) -> int:
|
||||||
if isinstance(motor, str):
|
if isinstance(motor, str):
|
||||||
return self.motors[motor][0]
|
return self.motors[motor].id
|
||||||
elif isinstance(motor, int):
|
elif isinstance(motor, int):
|
||||||
return motor
|
return motor
|
||||||
|
elif isinstance(motor, Motor):
|
||||||
|
return motor.id
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{motor} should be int or str.")
|
raise ValueError(f"{motor} should be int, str or Motor.")
|
||||||
|
|
||||||
def read(
|
def read(
|
||||||
self, data_name: str, motors: str | int | list[str | int] | None = None, num_retry: int = 1
|
self, data_name: str, motors: MotorLike | list[MotorLike] | None = None, num_retry: int = 1
|
||||||
) -> dict[str, float]:
|
) -> dict[str, float]:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
|
@ -472,17 +483,17 @@ class MotorsBus(abc.ABC):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if motors is None:
|
if motors is None:
|
||||||
motors = self.motor_ids
|
motors = self.ids
|
||||||
|
|
||||||
if isinstance(motors, (str, int)):
|
if isinstance(motors, (str, int)):
|
||||||
motors = [motors]
|
motors = [motors]
|
||||||
|
|
||||||
motor_ids = [self.get_safe_id(motor) for motor in motors]
|
motor_ids = [self.get_motor_id(motor) for motor in motors]
|
||||||
if self._has_different_ctrl_tables:
|
if self._has_different_ctrl_tables:
|
||||||
models = [self.idx_to_model(idx) for idx in motor_ids]
|
models = [self.id_to_model(idx) for idx in motor_ids]
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
|
|
||||||
model = self.idx_to_model(next(iter(motor_ids)))
|
model = self.id_to_model(next(iter(motor_ids)))
|
||||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||||
|
|
||||||
comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry)
|
comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry)
|
||||||
|
@ -496,7 +507,7 @@ class MotorsBus(abc.ABC):
|
||||||
ids_values = self.calibrate_values(ids_values)
|
ids_values = self.calibrate_values(ids_values)
|
||||||
|
|
||||||
# TODO(aliberts): return keys in the same format we got them?
|
# TODO(aliberts): return keys in the same format we got them?
|
||||||
ids_values = {self.idx_to_name(idx): val for idx, val in ids_values.items()}
|
ids_values = {self.id_to_name(idx): val for idx, val in ids_values.items()}
|
||||||
|
|
||||||
# log the number of seconds it took to read the data from the motors
|
# log the number of seconds it took to read the data from the motors
|
||||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
|
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
|
||||||
|
@ -540,7 +551,7 @@ class MotorsBus(abc.ABC):
|
||||||
# for idx in motor_ids:
|
# for idx in motor_ids:
|
||||||
# value = self.reader.getData(idx, address, n_bytes)
|
# value = self.reader.getData(idx, address, n_bytes)
|
||||||
|
|
||||||
def write(self, data_name: str, values_dict: dict[str | int, int], num_retry: int = 1) -> None:
|
def write(self, data_name: str, values_dict: dict[MotorLike, int], num_retry: int = 1) -> None:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||||
|
@ -548,16 +559,16 @@ class MotorsBus(abc.ABC):
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
ids_values = {self.get_safe_id(motor): val for motor, val in values_dict.items()}
|
ids_values = {self.get_motor_id(motor): val for motor, val in values_dict.items()}
|
||||||
|
|
||||||
if self._has_different_ctrl_tables:
|
if self._has_different_ctrl_tables:
|
||||||
models = [self.idx_to_model(idx) for idx in ids_values]
|
models = [self.id_to_model(idx) for idx in ids_values]
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
|
|
||||||
if data_name in self.calibration_required and self.calibration is not None:
|
if data_name in self.calibration_required and self.calibration is not None:
|
||||||
ids_values = self.uncalibrate_values(ids_values)
|
ids_values = self.uncalibrate_values(ids_values)
|
||||||
|
|
||||||
model = self.idx_to_model(next(iter(ids_values)))
|
model = self.id_to_model(next(iter(ids_values)))
|
||||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||||
|
|
||||||
comm = self._write(ids_values, addr, n_bytes, num_retry)
|
comm = self._write(ids_values, addr, n_bytes, num_retry)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.common.motors import Motor
|
||||||
from lerobot.common.motors.dynamixel import DynamixelMotorsBus
|
from lerobot.common.motors.dynamixel import DynamixelMotorsBus
|
||||||
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
|
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
|
||||||
|
|
||||||
|
@ -17,6 +18,15 @@ def patch_port_handler():
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_motors() -> dict[str, Motor]:
|
||||||
|
return {
|
||||||
|
"dummy_1": Motor(id=1, model="xl430-w250"),
|
||||||
|
"dummy_2": Motor(id=2, model="xm540-w270"),
|
||||||
|
"dummy_3": Motor(id=3, model="xl330-m077"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
|
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
|
||||||
def test_autouse_patch():
|
def test_autouse_patch():
|
||||||
"""Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler."""
|
"""Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler."""
|
||||||
|
@ -68,8 +78,9 @@ def test_split_int_bytes_large_number():
|
||||||
DynamixelMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
DynamixelMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
||||||
|
|
||||||
|
|
||||||
def test_abc_implementation():
|
def test_abc_implementation(dummy_motors):
|
||||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
||||||
|
DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
|
||||||
DynamixelMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "xl330-m077")})
|
DynamixelMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "xl330-m077")})
|
||||||
|
|
||||||
|
|
||||||
|
@ -83,17 +94,13 @@ def test_abc_implementation():
|
||||||
],
|
],
|
||||||
ids=["None", "by ids", "by names", "mixed"],
|
ids=["None", "by ids", "by names", "mixed"],
|
||||||
)
|
)
|
||||||
def test_read_all_motors(motors):
|
def test_read_all_motors(motors, dummy_motors):
|
||||||
mock_motors = MockMotors([1, 2, 3])
|
mock_motors = MockMotors([1, 2, 3])
|
||||||
positions = [1337, 42, 4016]
|
positions = [1337, 42, 4016]
|
||||||
mock_motors.build_all_motors_stub("Present_Position", return_values=positions)
|
mock_motors.build_sync_read_all_motors_stub("Present_Position", return_values=positions)
|
||||||
motors_bus = DynamixelMotorsBus(
|
motors_bus = DynamixelMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors={
|
motors=dummy_motors,
|
||||||
"dummy_1": (1, "xl330-m077"),
|
|
||||||
"dummy_2": (2, "xl330-m077"),
|
|
||||||
"dummy_3": (3, "xl330-m077"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
|
@ -113,16 +120,12 @@ def test_read_all_motors(motors):
|
||||||
[3, 4016],
|
[3, 4016],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_read_single_motor_by_name(idx, pos):
|
def test_read_single_motor_by_name(idx, pos, dummy_motors):
|
||||||
mock_motors = MockMotors([1, 2, 3])
|
mock_motors = MockMotors([1, 2, 3])
|
||||||
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
|
mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos)
|
||||||
motors_bus = DynamixelMotorsBus(
|
motors_bus = DynamixelMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors={
|
motors=dummy_motors,
|
||||||
"dummy_1": (1, "xl330-m077"),
|
|
||||||
"dummy_2": (2, "xl330-m077"),
|
|
||||||
"dummy_3": (3, "xl330-m077"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
|
@ -141,16 +144,12 @@ def test_read_single_motor_by_name(idx, pos):
|
||||||
[3, 4016],
|
[3, 4016],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_read_single_motor_by_id(idx, pos):
|
def test_read_single_motor_by_id(idx, pos, dummy_motors):
|
||||||
mock_motors = MockMotors([1, 2, 3])
|
mock_motors = MockMotors([1, 2, 3])
|
||||||
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
|
mock_motors.build_sync_read_single_motor_stubs("Present_Position", return_value=pos)
|
||||||
motors_bus = DynamixelMotorsBus(
|
motors_bus = DynamixelMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors={
|
motors=dummy_motors,
|
||||||
"dummy_1": (1, "xl330-m077"),
|
|
||||||
"dummy_2": (2, "xl330-m077"),
|
|
||||||
"dummy_3": (3, "xl330-m077"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
|
@ -170,18 +169,14 @@ def test_read_single_motor_by_id(idx, pos):
|
||||||
[2, 1, 999],
|
[2, 1, 999],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_read_num_retry(num_retry, num_invalid_try, pos):
|
def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
|
||||||
mock_motors = MockMotors([1, 2, 3])
|
mock_motors = MockMotors([1, 2, 3])
|
||||||
mock_motors.build_single_motor_stubs(
|
mock_motors.build_sync_read_single_motor_stubs(
|
||||||
"Present_Position", return_value=pos, num_invalid_try=num_invalid_try
|
"Present_Position", return_value=pos, num_invalid_try=num_invalid_try
|
||||||
)
|
)
|
||||||
motors_bus = DynamixelMotorsBus(
|
motors_bus = DynamixelMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors={
|
motors=dummy_motors,
|
||||||
"dummy_1": (1, "xl330-m077"),
|
|
||||||
"dummy_2": (2, "xl330-m077"),
|
|
||||||
"dummy_3": (3, "xl330-m077"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue