Improve feetech mocking

This commit is contained in:
Simon Alibert 2025-03-22 01:19:51 +01:00
parent fc4a95f187
commit 857f335be9
2 changed files with 179 additions and 99 deletions

View File

@ -77,6 +77,19 @@ class MockInstructionPacket(MockFeetechPacket):
0x00, # placeholder for checksum
] # fmt: skip
@classmethod
def ping(
cls,
scs_id: int,
) -> bytes:
"""
Builds a "Ping" broadcast instruction.
No parameters required.
"""
params, length = [], 2
return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Ping")
@classmethod
def sync_read(
cls,
@ -128,6 +141,25 @@ class MockStatusPacket(MockFeetechPacket):
0x00, # placeholder for checksum
] # fmt: skip
@classmethod
def ping(cls, scs_id: int, model_nb: int = 1190, firm_ver: int = 50) -> bytes:
"""Builds a 'Ping' status packet.
Args:
scs_id (int): ID of the servo responding.
model_nb (int, optional): Desired 'model number' to be returned in the packet. Defaults to 1190
which corresponds to a XL330-M077-T.
firm_ver (int, optional): Desired 'firmware version' to be returned in the packet.
Defaults to 50.
Returns:
bytes: The raw 'Ping' status packet ready to be sent through serial.
"""
# raise NotImplementedError
params = [scs.SCS_LOBYTE(model_nb), scs.SCS_HIBYTE(model_nb), firm_ver]
length = 2
return cls.build(scs_id, params=params, length=length)
@classmethod
def present_position(cls, scs_id: int, pos: int | None = None, min_max_range: tuple = (0, 4095)) -> bytes:
"""Builds a 'Present_Position' status packet.
@ -184,62 +216,69 @@ class MockMotors(MockSerial):
ctrl_table = SCS_SERIES_CONTROL_TABLE
def __init__(self, scs_ids: list[int]):
def __init__(self):
super().__init__()
self._ids = scs_ids
self.open()
def build_single_motor_stubs(
self, data_name: str, return_value: int | None = None, num_invalid_try: int | None = None
) -> None:
address, length = self.ctrl_table[data_name]
for idx in self._ids:
if data_name == "Present_Position":
sync_read_request_single = MockInstructionPacket.sync_read([idx], address, length)
sync_read_response_single = self._build_present_pos_send_fn(
[idx], [return_value], num_invalid_try
)
else:
raise NotImplementedError # TODO(aliberts): add ping?
self.stub(
name=f"SyncRead_{data_name}_{idx}",
receive_bytes=sync_read_request_single,
send_fn=sync_read_response_single,
)
def build_all_motors_stub(
self, data_name: str, return_values: list[int] | None = None, num_invalid_try: int | None = None
) -> None:
address, length = self.ctrl_table[data_name]
if data_name == "Present_Position":
sync_read_request_all = MockInstructionPacket.sync_read(self._ids, address, length)
sync_read_response_all = self._build_present_pos_send_fn(
self._ids, return_values, num_invalid_try
)
else:
raise NotImplementedError # TODO(aliberts): add ping?
self.stub(
name=f"SyncRead_{data_name}_all",
receive_bytes=sync_read_request_all,
send_fn=sync_read_response_all,
def build_broadcast_ping_stub(
self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0
) -> str:
ping_request = MockInstructionPacket.ping(scs.BROADCAST_ID)
return_packets = b"".join(
MockStatusPacket.ping(idx, model, firm_ver)
for idx, (model, firm_ver) in ids_models_firmwares.items()
)
ping_response = self._build_send_fn(return_packets, num_invalid_try)
def _build_present_pos_send_fn(
self, scs_ids: list[int], return_pos: list[int] | None = None, num_invalid_try: int | None = None
) -> Callable[[int], bytes]:
return_pos = [None for _ in scs_ids] if return_pos is None else return_pos
assert len(return_pos) == len(scs_ids)
stub_name = "Ping_" + "_".join([str(idx) for idx in ids_models_firmwares])
self.stub(
name=stub_name,
receive_bytes=ping_request,
send_fn=ping_response,
)
return stub_name
def build_ping_stub(
self, scs_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0
) -> str:
ping_request = MockInstructionPacket.ping(scs_id)
return_packet = MockStatusPacket.ping(scs_id, model_nb, firm_ver)
ping_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Ping_{scs_id}"
self.stub(
name=stub_name,
receive_bytes=ping_request,
send_fn=ping_response,
)
return stub_name
def build_sync_read_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
if data_name != "Present_Position":
raise NotImplementedError
return_packets = b"".join(
MockStatusPacket.present_position(idx, pos) for idx, pos in ids_values.items()
)
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
stub_name = f"Sync_Read_{data_name}_" + "_".join([str(idx) for idx in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=sync_read_response,
)
return stub_name
@staticmethod
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes:
if num_invalid_try is not None and num_invalid_try >= _call_count:
if num_invalid_try >= _call_count:
return b""
packets = b"".join(
MockStatusPacket.present_position(idx, pos)
for idx, pos in zip(scs_ids, return_pos, strict=True)
)
return packets
return packet
return send_fn

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
import scservo_sdk as scs
from lerobot.common.motors import Motor
from lerobot.common.motors.feetech import FeetechMotorsBus
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
@ -17,6 +18,15 @@ def patch_port_handler():
yield
@pytest.fixture
def dummy_motors() -> dict[str, Motor]:
return {
"dummy_1": Motor(id=1, model="sts3215"),
"dummy_2": Motor(id=2, model="sts3215"),
"dummy_3": Motor(id=3, model="sts3215"),
}
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
def test_autouse_patch():
"""Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler."""
@ -68,9 +78,52 @@ def test_split_int_bytes_large_number():
FeetechMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
def test_abc_implementation():
def test_abc_implementation(dummy_motors):
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
FeetechMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "sts3215")})
FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
@pytest.mark.skip("TODO")
@pytest.mark.parametrize(
"idx, model_nb",
[
[1, 1190],
[2, 1200],
[3, 1120],
],
)
def test_ping(idx, model_nb, dummy_motors):
mock_motors = MockMotors()
mock_motors.build_ping_stub(idx, model_nb)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect()
ping_model_nb = motors_bus.ping(idx)
assert ping_model_nb == model_nb
@pytest.mark.skip("TODO")
def test_broadcast_ping(dummy_motors):
expected_pings = {
1: [1060, 50],
2: [1120, 30],
3: [1190, 10],
}
mock_motors = MockMotors()
mock_motors.build_broadcast_ping_stub(expected_pings)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect()
ping_list = motors_bus.broadcast_ping()
assert ping_list == expected_pings
@pytest.mark.parametrize(
@ -83,26 +136,25 @@ def test_abc_implementation():
],
ids=["None", "by ids", "by names", "mixed"],
)
def test_read_all_motors(motors):
mock_motors = MockMotors([1, 2, 3])
positions = [1337, 42, 4016]
mock_motors.build_all_motors_stub("Present_Position", return_values=positions)
def test_read_all_motors(motors, dummy_motors):
mock_motors = MockMotors()
expected_positions = {
1: 1337,
2: 42,
3: 4016,
}
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_positions)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors={
"dummy_1": (1, "sts3215"),
"dummy_2": (2, "sts3215"),
"dummy_3": (3, "sts3215"),
},
motors=dummy_motors,
)
motors_bus.connect()
pos_dict = motors_bus.read("Present_Position", motors=motors)
positions_read = motors_bus.read("Present_Position", motors=motors)
assert mock_motors.stubs["SyncRead_Present_Position_all"].called
assert all(returned_pos == pos for returned_pos, pos in zip(pos_dict.values(), positions, strict=True))
assert set(pos_dict) == {"dummy_1", "dummy_2", "dummy_3"}
assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values())
motors = ["dummy_1", "dummy_2", "dummy_3"] if motors is None else motors
assert mock_motors.stubs[stub_name].called
assert positions_read == dict(zip(motors, expected_positions.values(), strict=True))
@pytest.mark.parametrize(
@ -113,24 +165,20 @@ def test_read_all_motors(motors):
[3, 4016],
],
)
def test_read_single_motor_by_name(idx, pos):
mock_motors = MockMotors([1, 2, 3])
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
def test_read_single_motor_by_name(idx, pos, dummy_motors):
mock_motors = MockMotors()
expected_position = {idx: pos}
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_position)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors={
"dummy_1": (1, "sts3215"),
"dummy_2": (2, "sts3215"),
"dummy_3": (3, "sts3215"),
},
motors=dummy_motors,
)
motors_bus.connect()
pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}")
assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called
assert mock_motors.stubs[stub_name].called
assert pos_dict == {f"dummy_{idx}": pos}
assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values())
@pytest.mark.parametrize(
@ -141,56 +189,49 @@ def test_read_single_motor_by_name(idx, pos):
[3, 4016],
],
)
def test_read_single_motor_by_id(idx, pos):
mock_motors = MockMotors([1, 2, 3])
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
def test_read_single_motor_by_id(idx, pos, dummy_motors):
mock_motors = MockMotors()
expected_position = {idx: pos}
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_position)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors={
"dummy_1": (1, "sts3215"),
"dummy_2": (2, "sts3215"),
"dummy_3": (3, "sts3215"),
},
motors=dummy_motors,
)
motors_bus.connect()
pos_dict = motors_bus.read("Present_Position", idx)
assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called
assert pos_dict == {f"dummy_{idx}": pos}
assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values())
assert mock_motors.stubs[stub_name].called
assert pos_dict == {idx: pos}
@pytest.mark.parametrize(
"num_retry, num_invalid_try, pos",
[
[1, 2, 1337],
[0, 2, 1337],
[2, 3, 42],
[3, 2, 4016],
[2, 1, 999],
],
)
def test_read_num_retry(num_retry, num_invalid_try, pos):
mock_motors = MockMotors([1, 2, 3])
mock_motors.build_single_motor_stubs(
"Present_Position", return_value=pos, num_invalid_try=num_invalid_try
def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
mock_motors = MockMotors()
expected_position = {1: pos}
stub_name = mock_motors.build_sync_read_stub(
"Present_Position", expected_position, num_invalid_try=num_invalid_try
)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors={
"dummy_1": (1, "sts3215"),
"dummy_2": (2, "sts3215"),
"dummy_3": (3, "sts3215"),
},
motors=dummy_motors,
)
motors_bus.connect()
if num_retry >= num_invalid_try:
pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry)
assert pos_dict == {"dummy_1": pos}
assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values())
assert pos_dict == {1: pos}
else:
with pytest.raises(ConnectionError):
_ = motors_bus.read("Present_Position", 1, num_retry=num_retry)
assert mock_motors.stubs["SyncRead_Present_Position_1"].calls == num_retry
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
assert mock_motors.stubs[stub_name].calls == expected_calls