lerobot/tests/motors/test_motors_bus.py

466 lines
15 KiB
Python

# ruff: noqa: N802
import re
from unittest.mock import patch
import pytest
from lerobot.common.motors.motors_bus import (
Motor,
MotorNormMode,
MotorsBus,
assert_same_address,
get_address,
get_ctrl_table,
)
# Names of the dummy motor models that every per-model table below is keyed by.
_DUMMY_MODELS = ("model_1", "model_2", "model_3")

# Control tables map a data name to an (address, number of bytes) pair.
DUMMY_CTRL_TABLE_1 = dict(
    Firmware_Version=(0, 1),
    Model_Number=(1, 2),
    Present_Position=(3, 4),
    Goal_Position=(11, 2),
)
DUMMY_CTRL_TABLE_2 = dict(
    Model_Number=(0, 2),
    Firmware_Version=(2, 1),
    Present_Position=(3, 4),
    Present_Velocity=(7, 4),
    Goal_Position=(11, 4),
    Goal_Velocity=(15, 4),
    Lock=(19, 1),
)
# "model_2" and "model_3" point at the same table object.
DUMMY_MODEL_CTRL_TABLE = dict(
    zip(_DUMMY_MODELS, (DUMMY_CTRL_TABLE_1, DUMMY_CTRL_TABLE_2, DUMMY_CTRL_TABLE_2))
)

# Integer key -> baudrate (presumably the register value selecting that baudrate; confirm against MotorsBus).
DUMMY_BAUDRATE_TABLE = dict(enumerate((1_000_000, 500_000, 250_000)))
DUMMY_MODEL_BAUDRATE_TABLE = dict.fromkeys(_DUMMY_MODELS, DUMMY_BAUDRATE_TABLE)

# Data name -> integer encoding parameter (presumably a sign-bit index; confirm against MotorsBus).
DUMMY_ENCODING_TABLE = dict(Present_Position=8, Goal_Position=10)
DUMMY_MODEL_ENCODING_TABLE = dict.fromkeys(_DUMMY_MODELS, DUMMY_ENCODING_TABLE)

# Model name -> model number.
DUMMY_MODEL_NUMBER_TABLE = dict(zip(_DUMMY_MODELS, (1234, 5678, 5799)))

# Model name -> resolution.
DUMMY_MODEL_RESOLUTION_TABLE = dict(zip(_DUMMY_MODELS, (4096, 1024, 4096)))
class MockPortHandler:
    """In-memory stand-in for a serial port handler; it never touches hardware.

    Only the open flag, the port name and the baudrate are tracked. Every other
    method is a no-op that returns ``None``. The camelCase method names presumably
    mirror the real port handler's interface (hence the file-level ``N802``
    suppression) -- confirm against ``MotorsBus``.
    """

    def __init__(self, port_name):
        """Create a closed, unused handler bound to ``port_name``."""
        self.is_open: bool = False
        # The bare annotations below only declare the attributes; they are not
        # created until something assigns them (``baudrate`` via ``setBaudRate``).
        self.baudrate: int
        self.packet_start_time: float
        self.packet_timeout: float
        self.tx_time_per_byte: float
        self.is_using: bool = False
        self.port_name: str = port_name
        self.ser = None

    def openPort(self):
        """Mark the port as open and return the new state (always ``True``)."""
        self.is_open = True
        return self.is_open

    def closePort(self):
        """Mark the port as closed."""
        self.is_open = False

    def clearPort(self): ...

    def setPortName(self, port_name):
        """Replace the stored port name."""
        self.port_name = port_name

    def getPortName(self):
        """Return the stored port name."""
        return self.port_name

    def setBaudRate(self, baudrate):
        """Store ``baudrate`` so that ``getBaudRate`` can return it."""
        # Was ``self.baudrate: baudrate`` -- an annotation, which assigns nothing.
        self.baudrate = baudrate

    def getBaudRate(self):
        """Return the last baudrate set.

        Raises:
            AttributeError: if ``setBaudRate`` has not been called yet.
        """
        return self.baudrate

    def getBytesAvailable(self): ...

    def readPort(self, length): ...

    def writePort(self, packet): ...

    def setPacketTimeout(self, packet_length): ...

    def setPacketTimeoutMillis(self, msec): ...

    def isPacketTimeout(self): ...

    def getCurrentTime(self): ...

    def getTimeSinceStart(self): ...

    def setupPort(self, cflag_baud): ...

    def getCFlagBaud(self, baudrate): ...
class MockMotorsBus(MotorsBus):
    """``MotorsBus`` subclass wired to the dummy tables and a ``MockPortHandler``.

    All the methods defined below ``__init__`` are stubs returning ``None``;
    individual tests patch the ones whose return value they need.
    """

    # Communication settings.
    default_timeout = 1000
    available_baudrates = [500_000, 1_000_000]

    # Per-model lookup tables.
    model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
    model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
    model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
    model_number_table = DUMMY_MODEL_NUMBER_TABLE
    model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE

    # Data names the tests expect to go through (un)normalization.
    normalized_data = ["Present_Position", "Goal_Position"]

    def __init__(self, port: str, motors: dict[str, Motor]):
        super().__init__(port, motors)
        # Assigned after the base initialiser so it wins over any handler set there.
        self.port_handler = MockPortHandler(port)

    def _assert_protocol_is_compatible(self, instruction_name):
        pass

    def configure_motors(self):
        pass

    def disable_torque(self, motors):
        pass

    def enable_torque(self, motors):
        pass

    def _get_half_turn_homings(self, positions):
        pass

    def _encode_sign(self, data_name, ids_values):
        pass

    def _decode_sign(self, data_name, ids_values):
        pass

    def _split_into_byte_chunks(self, value, length):
        pass

    def broadcast_ping(self, num_retry, raise_on_error):
        pass
@pytest.fixture
def dummy_motors() -> dict[str, Motor]:
    """Return three motors named ``dummy_<id>`` (ids 1 to 3) using models 2 and 3."""
    specs = (
        (1, "model_2", MotorNormMode.RANGE_M100_100),
        (2, "model_3", MotorNormMode.RANGE_M100_100),
        (3, "model_2", MotorNormMode.RANGE_0_100),
    )
    return {f"dummy_{motor_id}": Motor(motor_id, model, norm_mode) for motor_id, model, norm_mode in specs}
def test_get_ctrl_table():
    """``get_ctrl_table`` returns the control table registered for a known model."""
    assert get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, "model_1") == DUMMY_CTRL_TABLE_1
def test_get_ctrl_table_error():
    """``get_ctrl_table`` raises ``KeyError`` with an explicit message for an unknown model."""
    model = "model_99"
    # ``match`` is a regex searched in the exception text: escape the message so
    # that the trailing '.' (and any other metacharacter) is matched literally.
    expected_msg = re.escape(f"Control table for {model=} not found.")
    with pytest.raises(KeyError, match=expected_msg):
        get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
def test_get_address():
    """``get_address`` yields the (address, byte count) pair stored in the control table."""
    address, byte_count = get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", "Firmware_Version")
    assert address == 0
    assert byte_count == 1
def test_get_address_error():
    """``get_address`` raises ``KeyError`` when the data name is absent from the model's table."""
    model = "model_1"
    data_name = "Lock"  # only defined in DUMMY_CTRL_TABLE_2, not in model_1's table
    # ``match`` is a regex searched in the exception text: escape the message so
    # that the trailing '.' (and any other metacharacter) is matched literally.
    expected_msg = re.escape(f"Address for '{data_name}' not found in {model} control table.")
    with pytest.raises(KeyError, match=expected_msg):
        # Use ``model`` rather than a duplicated literal so the call and the
        # expected message cannot drift apart.
        get_address(DUMMY_MODEL_CTRL_TABLE, model, data_name)
def test_assert_same_address():
    """No error is raised when both models store the data at the same address and size."""
    assert_same_address(DUMMY_MODEL_CTRL_TABLE, ["model_1", "model_2"], "Present_Position")
def test_assert_same_length_different_addresses():
    """Same byte count but different addresses across models must be rejected."""
    expected_msg = re.escape("At least two motor models use a different address")
    with pytest.raises(NotImplementedError, match=expected_msg):
        assert_same_address(DUMMY_MODEL_CTRL_TABLE, ["model_1", "model_2"], "Model_Number")
def test_assert_same_address_different_length():
    """Same address but different byte counts across models must be rejected."""
    expected_msg = re.escape("At least two motor models use a different bytes representation")
    with pytest.raises(NotImplementedError, match=expected_msg):
        assert_same_address(DUMMY_MODEL_CTRL_TABLE, ["model_1", "model_2"], "Goal_Position")
def test__serialize_data_invalid_length():
    """Serializing over a 3-byte length is expected to raise ``NotImplementedError``."""
    motors_bus = MockMotorsBus("", {})
    value, n_bytes = 100, 3
    with pytest.raises(NotImplementedError):
        motors_bus._serialize_data(value, n_bytes)
def test__serialize_data_negative_numbers():
    """Serializing a negative value is expected to raise ``ValueError``."""
    motors_bus = MockMotorsBus("", {})
    value, n_bytes = -1, 1
    with pytest.raises(ValueError):
        motors_bus._serialize_data(value, n_bytes)
def test__serialize_data_large_number():
    """A value one past the 4-byte maximum (0xFFFFFFFF) is expected to raise ``ValueError``."""
    motors_bus = MockMotorsBus("", {})
    too_large, n_bytes = 2**32, 4
    with pytest.raises(ValueError):
        motors_bus._serialize_data(too_large, n_bytes)
@pytest.mark.parametrize(
    ("data_name", "id_", "value"),
    [
        ("Firmware_Version", 1, 14),
        ("Model_Number", 1, 5678),
        ("Present_Position", 2, 1337),
        ("Present_Velocity", 3, 42),
    ],
)
def test_read(data_name, id_, value, dummy_motors):
    """``bus.read`` forwards address, length and id to ``_read`` and returns the value.

    ``_read``, ``_decode_sign`` and ``_normalize`` are patched on the class, so only
    the way ``read`` wires them together is checked.
    """
    addr, length = DUMMY_CTRL_TABLE_2[data_name]
    motor_name = f"dummy_{id_}"
    bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
    bus.connect(handshake=False)

    read_patch = patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0))
    decode_patch = patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value})
    normalize_patch = patch.object(MockMotorsBus, "_normalize", return_value={id_: value})
    with read_patch as mock__read, decode_patch as mock__decode_sign, normalize_patch as mock__normalize:
        returned_value = bus.read(data_name, motor_name)

    assert returned_value == value

    expected_err_msg = f"Failed to read '{data_name}' on {id_=} after 1 tries."
    mock__read.assert_called_once_with(
        addr, length, id_, num_retry=0, raise_on_error=True, err_msg=expected_err_msg
    )
    mock__decode_sign.assert_called_once_with(data_name, {id_: value})
    # Normalization is only checked for the data names declared as normalized.
    if data_name in bus.normalized_data:
        mock__normalize.assert_called_once_with(data_name, {id_: value})
@pytest.mark.parametrize(
    ("data_name", "id_", "value"),
    [
        ("Goal_Position", 1, 1337),
        ("Goal_Velocity", 2, 3682),
        ("Lock", 3, 1),
    ],
)
def test_write(data_name, id_, value, dummy_motors):
    """``bus.write`` forwards address, length, id and value to ``_write``.

    ``_write``, ``_encode_sign`` and ``_unnormalize`` are patched on the class, so
    only the way ``write`` wires them together is checked.
    """
    addr, length = DUMMY_CTRL_TABLE_2[data_name]
    motor_name = f"dummy_{id_}"
    bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
    bus.connect(handshake=False)

    write_patch = patch.object(MockMotorsBus, "_write", return_value=(0, 0))
    encode_patch = patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value})
    unnormalize_patch = patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value})
    with write_patch as mock__write, encode_patch as mock__encode_sign, unnormalize_patch as mock__unnormalize:
        bus.write(data_name, motor_name, value)

    expected_err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries."
    mock__write.assert_called_once_with(
        addr, length, id_, value, num_retry=0, raise_on_error=True, err_msg=expected_err_msg
    )
    mock__encode_sign.assert_called_once_with(data_name, {id_: value})
    # Unnormalization is only checked for the data names declared as normalized.
    if data_name in bus.normalized_data:
        mock__unnormalize.assert_called_once_with(data_name, {id_: value})
@pytest.mark.parametrize(
    ("data_name", "id_", "value"),
    [
        ("Firmware_Version", 1, 14),
        ("Model_Number", 1, 5678),
        ("Present_Position", 2, 1337),
        ("Present_Velocity", 3, 42),
    ],
)
def test_sync_read_by_str(data_name, id_, value, dummy_motors):
    """``bus.sync_read`` given a single motor name returns ``{name: value}``.

    ``_sync_read``, ``_decode_sign`` and ``_normalize`` are patched on the class, so
    only the way ``sync_read`` wires them together is checked.
    """
    addr, length = DUMMY_CTRL_TABLE_2[data_name]
    motor_name = f"dummy_{id_}"
    ids = [id_]  # keep this name: it is interpolated as `{ids=}` in the expected message
    bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
    bus.connect(handshake=False)

    sync_read_patch = patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0))
    decode_patch = patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value})
    normalize_patch = patch.object(MockMotorsBus, "_normalize", return_value={id_: value})
    with sync_read_patch as mock__sync_read, decode_patch as mock__decode_sign, normalize_patch as mock__normalize:
        returned_dict = bus.sync_read(data_name, motor_name)

    assert returned_dict == {motor_name: value}

    expected_err_msg = f"Failed to sync read '{data_name}' on {ids=} after 1 tries."
    mock__sync_read.assert_called_once_with(
        addr, length, ids, num_retry=0, raise_on_error=True, err_msg=expected_err_msg
    )
    mock__decode_sign.assert_called_once_with(data_name, {id_: value})
    # Normalization is only checked for the data names declared as normalized.
    if data_name in bus.normalized_data:
        mock__normalize.assert_called_once_with(data_name, {id_: value})
@pytest.mark.parametrize(
    ("data_name", "ids_values"),
    [
        ("Model_Number", {1: 5678}),
        ("Present_Position", {1: 1337, 2: 42}),
        ("Present_Velocity", {1: 1337, 2: 42, 3: 4016}),
    ],
    ids=["1 motor", "2 motors", "3 motors"],
)
def test_sync_read_by_list(data_name, ids_values, dummy_motors):
    """``bus.sync_read`` given a list of motor names returns one value per name.

    ``_sync_read``, ``_decode_sign`` and ``_normalize`` are patched on the class, so
    only the way ``sync_read`` wires them together is checked.
    """
    addr, length = DUMMY_CTRL_TABLE_2[data_name]
    ids = list(ids_values)  # keep this name: it is interpolated as `{ids=}` below
    motor_names = [f"dummy_{motor_id}" for motor_id in ids]
    expected_values = dict(zip(motor_names, ids_values.values()))
    bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
    bus.connect(handshake=False)

    sync_read_patch = patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0))
    decode_patch = patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values)
    normalize_patch = patch.object(MockMotorsBus, "_normalize", return_value=ids_values)
    with sync_read_patch as mock__sync_read, decode_patch as mock__decode_sign, normalize_patch as mock__normalize:
        returned_dict = bus.sync_read(data_name, motor_names)

    assert returned_dict == expected_values

    expected_err_msg = f"Failed to sync read '{data_name}' on {ids=} after 1 tries."
    mock__sync_read.assert_called_once_with(
        addr, length, ids, num_retry=0, raise_on_error=True, err_msg=expected_err_msg
    )
    mock__decode_sign.assert_called_once_with(data_name, ids_values)
    # Normalization is only checked for the data names declared as normalized.
    if data_name in bus.normalized_data:
        mock__normalize.assert_called_once_with(data_name, ids_values)
@pytest.mark.parametrize(
    ("data_name", "ids_values"),
    [
        ("Model_Number", {1: 5678, 2: 5799, 3: 5678}),
        ("Present_Position", {1: 1337, 2: 42, 3: 4016}),
        ("Goal_Position", {1: 4008, 2: 199, 3: 3446}),
    ],
    ids=["Model_Number", "Present_Position", "Goal_Position"],
)
def test_sync_read_by_none(data_name, ids_values, dummy_motors):
    """``bus.sync_read`` called without motors is expected to read all three dummy motors.

    ``_sync_read``, ``_decode_sign`` and ``_normalize`` are patched on the class, so
    only the way ``sync_read`` wires them together is checked.
    """
    addr, length = DUMMY_CTRL_TABLE_2[data_name]
    ids = list(ids_values)  # keep this name: it is interpolated as `{ids=}` below
    expected_values = {f"dummy_{motor_id}": ids_values[motor_id] for motor_id in ids}
    bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
    bus.connect(handshake=False)

    sync_read_patch = patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0))
    decode_patch = patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values)
    normalize_patch = patch.object(MockMotorsBus, "_normalize", return_value=ids_values)
    with sync_read_patch as mock__sync_read, decode_patch as mock__decode_sign, normalize_patch as mock__normalize:
        returned_dict = bus.sync_read(data_name)

    assert returned_dict == expected_values

    expected_err_msg = f"Failed to sync read '{data_name}' on {ids=} after 1 tries."
    mock__sync_read.assert_called_once_with(
        addr, length, ids, num_retry=0, raise_on_error=True, err_msg=expected_err_msg
    )
    mock__decode_sign.assert_called_once_with(data_name, ids_values)
    # Normalization is only checked for the data names declared as normalized.
    if data_name in bus.normalized_data:
        mock__normalize.assert_called_once_with(data_name, ids_values)
@pytest.mark.parametrize(
    ("data_name", "value"),
    [
        ("Goal_Position", 500),
        ("Goal_Velocity", 4010),
        ("Lock", 0),
    ],
)
def test_sync_write_by_single_value(data_name, value, dummy_motors):
    """``bus.sync_write`` given a single value is expected to send it to every motor id.

    ``_sync_write``, ``_encode_sign`` and ``_unnormalize`` are patched on the class,
    so only the way ``sync_write`` wires them together is checked.
    """
    addr, length = DUMMY_CTRL_TABLE_2[data_name]
    # Keep this name: it is interpolated as `{ids_values=}` in the expected message.
    ids_values = {motor.id: value for motor in dummy_motors.values()}
    bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
    bus.connect(handshake=False)

    sync_write_patch = patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0))
    encode_patch = patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values)
    unnormalize_patch = patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values)
    with (
        sync_write_patch as mock__sync_write,
        encode_patch as mock__encode_sign,
        unnormalize_patch as mock__unnormalize,
    ):
        bus.sync_write(data_name, value)

    expected_err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries."
    mock__sync_write.assert_called_once_with(
        addr, length, ids_values, num_retry=0, raise_on_error=True, err_msg=expected_err_msg
    )
    mock__encode_sign.assert_called_once_with(data_name, ids_values)
    # Unnormalization is only checked for the data names declared as normalized.
    if data_name in bus.normalized_data:
        mock__unnormalize.assert_called_once_with(data_name, ids_values)
@pytest.mark.parametrize(
    ("data_name", "ids_values"),
    [
        ("Goal_Position", {1: 1337, 2: 42, 3: 4016}),
        ("Goal_Velocity", {1: 50, 2: 83, 3: 2777}),
        ("Lock", {1: 0, 2: 0, 3: 1}),
    ],
    ids=["Goal_Position", "Goal_Velocity", "Lock"],
)
def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors):
    """``bus.sync_write`` given a ``{motor name: value}`` dict is expected to send ``{id: value}``.

    ``_sync_write``, ``_encode_sign`` and ``_unnormalize`` are patched on the class,
    so only the way ``sync_write`` wires them together is checked.
    """
    addr, length = DUMMY_CTRL_TABLE_2[data_name]
    names_values = {f"dummy_{motor_id}": motor_value for motor_id, motor_value in ids_values.items()}
    bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
    bus.connect(handshake=False)

    sync_write_patch = patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0))
    encode_patch = patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values)
    unnormalize_patch = patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values)
    with (
        sync_write_patch as mock__sync_write,
        encode_patch as mock__encode_sign,
        unnormalize_patch as mock__unnormalize,
    ):
        bus.sync_write(data_name, names_values)

    expected_err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries."
    mock__sync_write.assert_called_once_with(
        addr, length, ids_values, num_retry=0, raise_on_error=True, err_msg=expected_err_msg
    )
    mock__encode_sign.assert_called_once_with(data_name, ids_values)
    # Unnormalization is only checked for the data names declared as normalized.
    if data_name in bus.normalized_data:
        mock__unnormalize.assert_called_once_with(data_name, ids_values)