Add feetech write tests

This commit is contained in:
Simon Alibert 2025-03-22 17:02:01 +01:00
parent 0ccc957d5c
commit 1d3e1cbdbd
2 changed files with 149 additions and 7 deletions

View File

@ -1,12 +1,14 @@
import abc import abc
import random import random
import threading
import time
from typing import Callable from typing import Callable
import scservo_sdk as scs import scservo_sdk as scs
import serial import serial
from mock_serial import MockSerial from mock_serial.mock_serial import MockSerial, Stub
from lerobot.common.motors.feetech.tables import SCS_SERIES_CONTROL_TABLE from lerobot.common.motors.feetech import SCS_SERIES_CONTROL_TABLE, FeetechMotorsBus
# https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf # https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf
INSTRUCTION_TYPES = { INSTRUCTION_TYPES = {
@ -100,12 +102,12 @@ class MockInstructionPacket(MockFeetechPacket):
""" """
Builds a "Sync_Read" broadcast instruction. Builds a "Sync_Read" broadcast instruction.
The parameters for Sync Read (Protocol 2.0) are: The parameters for Sync Read are:
param[0] = start_address param[0] = start_address
param[1] = data_length param[1] = data_length
param[2+] = motor IDs to read from param[2+] = motor IDs to read from
And 'length' = (number_of_params + 7), where: And 'length' = (number_of_params + 4), where:
+1 is for instruction byte, +1 is for instruction byte,
+1 is for the address byte, +1 is for the address byte,
+1 is for the length bytes, +1 is for the length bytes,
@ -115,6 +117,44 @@ class MockInstructionPacket(MockFeetechPacket):
length = len(scs_ids) + 4 length = len(scs_ids) + 4
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read") return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read")
@classmethod
def sync_write(
cls,
ids_values: dict[int],
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Sync_Write" broadcast instruction.
The parameters for Sync_Write are:
param[0] = start_address
param[1] = data_length
param[2] = [1st motor] ID
param[2+1] = [1st motor] 1st Byte
param[2+2] = [1st motor] 2nd Byte
...
param[5+X] = [1st motor] X-th Byte
param[6] = [2nd motor] ID
param[6+1] = [2nd motor] 1st Byte
param[6+2] = [2nd motor] 2nd Byte
...
param[6+X] = [2nd motor] X-th Byte
And 'length' = ((number_of_params * 1 + data_length) + 4), where:
+1 is for instruction byte,
+1 is for the address byte,
+1 is for the length bytes,
+1 is for the checksum at the end.
"""
data = []
for idx, value in ids_values.items():
split_value = FeetechMotorsBus.split_int_bytes(value, data_length)
data += [idx, *split_value]
params = [start_address, data_length, *data]
length = len(ids_values) * (1 + data_length) + 4
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write")
class MockStatusPacket(MockFeetechPacket): class MockStatusPacket(MockFeetechPacket):
""" """
@ -205,13 +245,38 @@ class MockPortHandler(scs.PortHandler):
return True return True
class WaitableStub(Stub):
"""
In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time
to read, match, and call the stub. In these situations, the test can fail randomly.
Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._event = threading.Event()
def call(self):
self._event.set()
return super().call()
def wait_called(self, timeout: float = 1.0):
return self._event.wait(timeout)
def wait_calls(self, min_calls: int = 1, timeout: float = 1.0):
start = time.perf_counter()
while time.perf_counter() - start < timeout:
if self.calls >= min_calls:
return self.calls
time.sleep(0.005)
raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.")
class MockMotors(MockSerial): class MockMotors(MockSerial):
""" """
This class will simulate physical motors by responding with valid status packets upon receiving some This class will simulate physical motors by responding with valid status packets upon receiving some
instruction packets. It is meant to test MotorsBus classes. instruction packets. It is meant to test MotorsBus classes.
'data_name' supported:
- Present_Position
""" """
ctrl_table = SCS_SERIES_CONTROL_TABLE ctrl_table = SCS_SERIES_CONTROL_TABLE
@ -220,6 +285,15 @@ class MockMotors(MockSerial):
super().__init__() super().__init__()
self.open() self.open()
@property
def stubs(self) -> dict[str, WaitableStub]:
return super().stubs
def stub(self, *, name=None, **kwargs):
new_stub = WaitableStub(**kwargs)
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
return new_stub
def build_broadcast_ping_stub( def build_broadcast_ping_stub(
self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0 self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0
) -> str: ) -> str:
@ -256,6 +330,10 @@ class MockMotors(MockSerial):
def build_sync_read_stub( def build_sync_read_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
) -> str: ) -> str:
"""
'data_name' supported:
- Present_Position
"""
address, length = self.ctrl_table[data_name] address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
if data_name != "Present_Position": if data_name != "Present_Position":
@ -274,6 +352,19 @@ class MockMotors(MockSerial):
) )
return stub_name return stub_name
def build_sync_write_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
stub_name = f"Sync_Write_{data_name}_" + "_".join([str(idx) for idx in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=self._build_send_fn(b"", num_invalid_try),
)
return stub_name
@staticmethod @staticmethod
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes: def send_fn(_call_count: int) -> bytes:

View File

@ -235,3 +235,54 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
expected_calls = min(1 + num_retry, 1 + num_invalid_try) expected_calls = min(1 + num_retry, 1 + num_invalid_try)
assert mock_motors.stubs[stub_name].calls == expected_calls assert mock_motors.stubs[stub_name].calls == expected_calls
@pytest.mark.parametrize(
"motors",
[
[1, 2, 3],
["dummy_1", "dummy_2", "dummy_3"],
[1, "dummy_2", 3],
],
ids=["by ids", "by names", "mixed"],
)
def test_write_all_motors(motors, dummy_motors):
mock_motors = MockMotors()
goal_positions = {
1: 1337,
2: 42,
3: 4016,
}
stub_name = mock_motors.build_sync_write_stub("Goal_Position", goal_positions)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect()
values = dict(zip(motors, goal_positions.values(), strict=True))
motors_bus.write("Goal_Position", values)
assert mock_motors.stubs[stub_name].wait_called()
@pytest.mark.parametrize(
"data_name, value",
[
["Torque_Enable", 0],
["Torque_Enable", 1],
],
)
def test_write_all_motors_single_value(data_name, value, dummy_motors):
mock_motors = MockMotors()
values = {m.id: value for m in dummy_motors.values()}
stub_name = mock_motors.build_sync_write_stub(data_name, values)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect()
motors_bus.write(data_name, value)
assert mock_motors.stubs[stub_name].wait_called()