Improve feetech mocking
This commit is contained in:
parent
fc4a95f187
commit
857f335be9
|
@ -77,6 +77,19 @@ class MockInstructionPacket(MockFeetechPacket):
|
||||||
0x00, # placeholder for checksum
|
0x00, # placeholder for checksum
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ping(
|
||||||
|
cls,
|
||||||
|
scs_id: int,
|
||||||
|
) -> bytes:
|
||||||
|
"""
|
||||||
|
Builds a "Ping" broadcast instruction.
|
||||||
|
|
||||||
|
No parameters required.
|
||||||
|
"""
|
||||||
|
params, length = [], 2
|
||||||
|
return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Ping")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sync_read(
|
def sync_read(
|
||||||
cls,
|
cls,
|
||||||
|
@ -128,6 +141,25 @@ class MockStatusPacket(MockFeetechPacket):
|
||||||
0x00, # placeholder for checksum
|
0x00, # placeholder for checksum
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ping(cls, scs_id: int, model_nb: int = 1190, firm_ver: int = 50) -> bytes:
|
||||||
|
"""Builds a 'Ping' status packet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scs_id (int): ID of the servo responding.
|
||||||
|
model_nb (int, optional): Desired 'model number' to be returned in the packet. Defaults to 1190
|
||||||
|
which corresponds to a XL330-M077-T.
|
||||||
|
firm_ver (int, optional): Desired 'firmware version' to be returned in the packet.
|
||||||
|
Defaults to 50.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: The raw 'Ping' status packet ready to be sent through serial.
|
||||||
|
"""
|
||||||
|
# raise NotImplementedError
|
||||||
|
params = [scs.SCS_LOBYTE(model_nb), scs.SCS_HIBYTE(model_nb), firm_ver]
|
||||||
|
length = 2
|
||||||
|
return cls.build(scs_id, params=params, length=length)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def present_position(cls, scs_id: int, pos: int | None = None, min_max_range: tuple = (0, 4095)) -> bytes:
|
def present_position(cls, scs_id: int, pos: int | None = None, min_max_range: tuple = (0, 4095)) -> bytes:
|
||||||
"""Builds a 'Present_Position' status packet.
|
"""Builds a 'Present_Position' status packet.
|
||||||
|
@ -184,62 +216,69 @@ class MockMotors(MockSerial):
|
||||||
|
|
||||||
ctrl_table = SCS_SERIES_CONTROL_TABLE
|
ctrl_table = SCS_SERIES_CONTROL_TABLE
|
||||||
|
|
||||||
def __init__(self, scs_ids: list[int]):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._ids = scs_ids
|
|
||||||
self.open()
|
self.open()
|
||||||
|
|
||||||
def build_single_motor_stubs(
|
def build_broadcast_ping_stub(
|
||||||
self, data_name: str, return_value: int | None = None, num_invalid_try: int | None = None
|
self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0
|
||||||
) -> None:
|
) -> str:
|
||||||
address, length = self.ctrl_table[data_name]
|
ping_request = MockInstructionPacket.ping(scs.BROADCAST_ID)
|
||||||
for idx in self._ids:
|
return_packets = b"".join(
|
||||||
if data_name == "Present_Position":
|
MockStatusPacket.ping(idx, model, firm_ver)
|
||||||
sync_read_request_single = MockInstructionPacket.sync_read([idx], address, length)
|
for idx, (model, firm_ver) in ids_models_firmwares.items()
|
||||||
sync_read_response_single = self._build_present_pos_send_fn(
|
|
||||||
[idx], [return_value], num_invalid_try
|
|
||||||
)
|
)
|
||||||
else:
|
ping_response = self._build_send_fn(return_packets, num_invalid_try)
|
||||||
raise NotImplementedError # TODO(aliberts): add ping?
|
|
||||||
|
|
||||||
|
stub_name = "Ping_" + "_".join([str(idx) for idx in ids_models_firmwares])
|
||||||
self.stub(
|
self.stub(
|
||||||
name=f"SyncRead_{data_name}_{idx}",
|
name=stub_name,
|
||||||
receive_bytes=sync_read_request_single,
|
receive_bytes=ping_request,
|
||||||
send_fn=sync_read_response_single,
|
send_fn=ping_response,
|
||||||
)
|
)
|
||||||
|
return stub_name
|
||||||
|
|
||||||
def build_all_motors_stub(
|
def build_ping_stub(
|
||||||
self, data_name: str, return_values: list[int] | None = None, num_invalid_try: int | None = None
|
self, scs_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0
|
||||||
) -> None:
|
) -> str:
|
||||||
address, length = self.ctrl_table[data_name]
|
ping_request = MockInstructionPacket.ping(scs_id)
|
||||||
if data_name == "Present_Position":
|
return_packet = MockStatusPacket.ping(scs_id, model_nb, firm_ver)
|
||||||
sync_read_request_all = MockInstructionPacket.sync_read(self._ids, address, length)
|
ping_response = self._build_send_fn(return_packet, num_invalid_try)
|
||||||
sync_read_response_all = self._build_present_pos_send_fn(
|
|
||||||
self._ids, return_values, num_invalid_try
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError # TODO(aliberts): add ping?
|
|
||||||
|
|
||||||
|
stub_name = f"Ping_{scs_id}"
|
||||||
self.stub(
|
self.stub(
|
||||||
name=f"SyncRead_{data_name}_all",
|
name=stub_name,
|
||||||
receive_bytes=sync_read_request_all,
|
receive_bytes=ping_request,
|
||||||
send_fn=sync_read_response_all,
|
send_fn=ping_response,
|
||||||
)
|
)
|
||||||
|
return stub_name
|
||||||
|
|
||||||
def _build_present_pos_send_fn(
|
def build_sync_read_stub(
|
||||||
self, scs_ids: list[int], return_pos: list[int] | None = None, num_invalid_try: int | None = None
|
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
|
||||||
) -> Callable[[int], bytes]:
|
) -> str:
|
||||||
return_pos = [None for _ in scs_ids] if return_pos is None else return_pos
|
address, length = self.ctrl_table[data_name]
|
||||||
assert len(return_pos) == len(scs_ids)
|
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||||
|
if data_name != "Present_Position":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
return_packets = b"".join(
|
||||||
|
MockStatusPacket.present_position(idx, pos) for idx, pos in ids_values.items()
|
||||||
|
)
|
||||||
|
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
|
||||||
|
|
||||||
|
stub_name = f"Sync_Read_{data_name}_" + "_".join([str(idx) for idx in ids_values])
|
||||||
|
self.stub(
|
||||||
|
name=stub_name,
|
||||||
|
receive_bytes=sync_read_request,
|
||||||
|
send_fn=sync_read_response,
|
||||||
|
)
|
||||||
|
return stub_name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
||||||
def send_fn(_call_count: int) -> bytes:
|
def send_fn(_call_count: int) -> bytes:
|
||||||
if num_invalid_try is not None and num_invalid_try >= _call_count:
|
if num_invalid_try >= _call_count:
|
||||||
return b""
|
return b""
|
||||||
|
return packet
|
||||||
packets = b"".join(
|
|
||||||
MockStatusPacket.present_position(idx, pos)
|
|
||||||
for idx, pos in zip(scs_ids, return_pos, strict=True)
|
|
||||||
)
|
|
||||||
return packets
|
|
||||||
|
|
||||||
return send_fn
|
return send_fn
|
||||||
|
|
|
@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
from lerobot.common.motors import Motor
|
||||||
from lerobot.common.motors.feetech import FeetechMotorsBus
|
from lerobot.common.motors.feetech import FeetechMotorsBus
|
||||||
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
|
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
|
||||||
|
|
||||||
|
@ -17,6 +18,15 @@ def patch_port_handler():
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_motors() -> dict[str, Motor]:
|
||||||
|
return {
|
||||||
|
"dummy_1": Motor(id=1, model="sts3215"),
|
||||||
|
"dummy_2": Motor(id=2, model="sts3215"),
|
||||||
|
"dummy_3": Motor(id=3, model="sts3215"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
|
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
|
||||||
def test_autouse_patch():
|
def test_autouse_patch():
|
||||||
"""Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler."""
|
"""Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler."""
|
||||||
|
@ -68,9 +78,52 @@ def test_split_int_bytes_large_number():
|
||||||
FeetechMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
FeetechMotorsBus.split_int_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
||||||
|
|
||||||
|
|
||||||
def test_abc_implementation():
|
def test_abc_implementation(dummy_motors):
|
||||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
||||||
FeetechMotorsBus(port="/dev/dummy-port", motors={"dummy": (1, "sts3215")})
|
FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("TODO")
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"idx, model_nb",
|
||||||
|
[
|
||||||
|
[1, 1190],
|
||||||
|
[2, 1200],
|
||||||
|
[3, 1120],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_ping(idx, model_nb, dummy_motors):
|
||||||
|
mock_motors = MockMotors()
|
||||||
|
mock_motors.build_ping_stub(idx, model_nb)
|
||||||
|
motors_bus = FeetechMotorsBus(
|
||||||
|
port=mock_motors.port,
|
||||||
|
motors=dummy_motors,
|
||||||
|
)
|
||||||
|
motors_bus.connect()
|
||||||
|
|
||||||
|
ping_model_nb = motors_bus.ping(idx)
|
||||||
|
|
||||||
|
assert ping_model_nb == model_nb
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("TODO")
|
||||||
|
def test_broadcast_ping(dummy_motors):
|
||||||
|
expected_pings = {
|
||||||
|
1: [1060, 50],
|
||||||
|
2: [1120, 30],
|
||||||
|
3: [1190, 10],
|
||||||
|
}
|
||||||
|
mock_motors = MockMotors()
|
||||||
|
mock_motors.build_broadcast_ping_stub(expected_pings)
|
||||||
|
motors_bus = FeetechMotorsBus(
|
||||||
|
port=mock_motors.port,
|
||||||
|
motors=dummy_motors,
|
||||||
|
)
|
||||||
|
motors_bus.connect()
|
||||||
|
|
||||||
|
ping_list = motors_bus.broadcast_ping()
|
||||||
|
|
||||||
|
assert ping_list == expected_pings
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -83,26 +136,25 @@ def test_abc_implementation():
|
||||||
],
|
],
|
||||||
ids=["None", "by ids", "by names", "mixed"],
|
ids=["None", "by ids", "by names", "mixed"],
|
||||||
)
|
)
|
||||||
def test_read_all_motors(motors):
|
def test_read_all_motors(motors, dummy_motors):
|
||||||
mock_motors = MockMotors([1, 2, 3])
|
mock_motors = MockMotors()
|
||||||
positions = [1337, 42, 4016]
|
expected_positions = {
|
||||||
mock_motors.build_all_motors_stub("Present_Position", return_values=positions)
|
1: 1337,
|
||||||
|
2: 42,
|
||||||
|
3: 4016,
|
||||||
|
}
|
||||||
|
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_positions)
|
||||||
motors_bus = FeetechMotorsBus(
|
motors_bus = FeetechMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors={
|
motors=dummy_motors,
|
||||||
"dummy_1": (1, "sts3215"),
|
|
||||||
"dummy_2": (2, "sts3215"),
|
|
||||||
"dummy_3": (3, "sts3215"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
pos_dict = motors_bus.read("Present_Position", motors=motors)
|
positions_read = motors_bus.read("Present_Position", motors=motors)
|
||||||
|
|
||||||
assert mock_motors.stubs["SyncRead_Present_Position_all"].called
|
motors = ["dummy_1", "dummy_2", "dummy_3"] if motors is None else motors
|
||||||
assert all(returned_pos == pos for returned_pos, pos in zip(pos_dict.values(), positions, strict=True))
|
assert mock_motors.stubs[stub_name].called
|
||||||
assert set(pos_dict) == {"dummy_1", "dummy_2", "dummy_3"}
|
assert positions_read == dict(zip(motors, expected_positions.values(), strict=True))
|
||||||
assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -113,24 +165,20 @@ def test_read_all_motors(motors):
|
||||||
[3, 4016],
|
[3, 4016],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_read_single_motor_by_name(idx, pos):
|
def test_read_single_motor_by_name(idx, pos, dummy_motors):
|
||||||
mock_motors = MockMotors([1, 2, 3])
|
mock_motors = MockMotors()
|
||||||
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
|
expected_position = {idx: pos}
|
||||||
|
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_position)
|
||||||
motors_bus = FeetechMotorsBus(
|
motors_bus = FeetechMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors={
|
motors=dummy_motors,
|
||||||
"dummy_1": (1, "sts3215"),
|
|
||||||
"dummy_2": (2, "sts3215"),
|
|
||||||
"dummy_3": (3, "sts3215"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}")
|
pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}")
|
||||||
|
|
||||||
assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called
|
assert mock_motors.stubs[stub_name].called
|
||||||
assert pos_dict == {f"dummy_{idx}": pos}
|
assert pos_dict == {f"dummy_{idx}": pos}
|
||||||
assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -141,56 +189,49 @@ def test_read_single_motor_by_name(idx, pos):
|
||||||
[3, 4016],
|
[3, 4016],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_read_single_motor_by_id(idx, pos):
|
def test_read_single_motor_by_id(idx, pos, dummy_motors):
|
||||||
mock_motors = MockMotors([1, 2, 3])
|
mock_motors = MockMotors()
|
||||||
mock_motors.build_single_motor_stubs("Present_Position", return_value=pos)
|
expected_position = {idx: pos}
|
||||||
|
stub_name = mock_motors.build_sync_read_stub("Present_Position", expected_position)
|
||||||
motors_bus = FeetechMotorsBus(
|
motors_bus = FeetechMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors={
|
motors=dummy_motors,
|
||||||
"dummy_1": (1, "sts3215"),
|
|
||||||
"dummy_2": (2, "sts3215"),
|
|
||||||
"dummy_3": (3, "sts3215"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
pos_dict = motors_bus.read("Present_Position", idx)
|
pos_dict = motors_bus.read("Present_Position", idx)
|
||||||
|
|
||||||
assert mock_motors.stubs[f"SyncRead_Present_Position_{idx}"].called
|
assert mock_motors.stubs[stub_name].called
|
||||||
assert pos_dict == {f"dummy_{idx}": pos}
|
assert pos_dict == {idx: pos}
|
||||||
assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_retry, num_invalid_try, pos",
|
"num_retry, num_invalid_try, pos",
|
||||||
[
|
[
|
||||||
[1, 2, 1337],
|
[0, 2, 1337],
|
||||||
[2, 3, 42],
|
[2, 3, 42],
|
||||||
[3, 2, 4016],
|
[3, 2, 4016],
|
||||||
[2, 1, 999],
|
[2, 1, 999],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_read_num_retry(num_retry, num_invalid_try, pos):
|
def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
|
||||||
mock_motors = MockMotors([1, 2, 3])
|
mock_motors = MockMotors()
|
||||||
mock_motors.build_single_motor_stubs(
|
expected_position = {1: pos}
|
||||||
"Present_Position", return_value=pos, num_invalid_try=num_invalid_try
|
stub_name = mock_motors.build_sync_read_stub(
|
||||||
|
"Present_Position", expected_position, num_invalid_try=num_invalid_try
|
||||||
)
|
)
|
||||||
motors_bus = FeetechMotorsBus(
|
motors_bus = FeetechMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors={
|
motors=dummy_motors,
|
||||||
"dummy_1": (1, "sts3215"),
|
|
||||||
"dummy_2": (2, "sts3215"),
|
|
||||||
"dummy_3": (3, "sts3215"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
motors_bus.connect()
|
motors_bus.connect()
|
||||||
|
|
||||||
if num_retry >= num_invalid_try:
|
if num_retry >= num_invalid_try:
|
||||||
pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry)
|
pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry)
|
||||||
assert pos_dict == {"dummy_1": pos}
|
assert pos_dict == {1: pos}
|
||||||
assert all(pos >= 0 and pos <= 4095 for pos in pos_dict.values())
|
|
||||||
else:
|
else:
|
||||||
with pytest.raises(ConnectionError):
|
with pytest.raises(ConnectionError):
|
||||||
_ = motors_bus.read("Present_Position", 1, num_retry=num_retry)
|
_ = motors_bus.read("Present_Position", 1, num_retry=num_retry)
|
||||||
|
|
||||||
assert mock_motors.stubs["SyncRead_Present_Position_1"].calls == num_retry
|
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
|
||||||
|
assert mock_motors.stubs[stub_name].calls == expected_calls
|
||||||
|
|
Loading…
Reference in New Issue