Add feetech write tests
This commit is contained in:
parent
0ccc957d5c
commit
1d3e1cbdbd
|
@ -1,12 +1,14 @@
|
|||
import abc
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
import scservo_sdk as scs
|
||||
import serial
|
||||
from mock_serial import MockSerial
|
||||
from mock_serial.mock_serial import MockSerial, Stub
|
||||
|
||||
from lerobot.common.motors.feetech.tables import SCS_SERIES_CONTROL_TABLE
|
||||
from lerobot.common.motors.feetech import SCS_SERIES_CONTROL_TABLE, FeetechMotorsBus
|
||||
|
||||
# https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf
|
||||
INSTRUCTION_TYPES = {
|
||||
|
@ -100,12 +102,12 @@ class MockInstructionPacket(MockFeetechPacket):
|
|||
"""
|
||||
Builds a "Sync_Read" broadcast instruction.
|
||||
|
||||
The parameters for Sync Read (Protocol 2.0) are:
|
||||
The parameters for Sync Read are:
|
||||
param[0] = start_address
|
||||
param[1] = data_length
|
||||
param[2+] = motor IDs to read from
|
||||
|
||||
And 'length' = (number_of_params + 7), where:
|
||||
And 'length' = (number_of_params + 4), where:
|
||||
+1 is for instruction byte,
|
||||
+1 is for the address byte,
|
||||
+1 is for the length bytes,
|
||||
|
@ -115,6 +117,44 @@ class MockInstructionPacket(MockFeetechPacket):
|
|||
length = len(scs_ids) + 4
|
||||
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read")
|
||||
|
||||
@classmethod
|
||||
def sync_write(
|
||||
cls,
|
||||
ids_values: dict[int],
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
"""
|
||||
Builds a "Sync_Write" broadcast instruction.
|
||||
|
||||
The parameters for Sync_Write are:
|
||||
param[0] = start_address
|
||||
param[1] = data_length
|
||||
param[2] = [1st motor] ID
|
||||
param[2+1] = [1st motor] 1st Byte
|
||||
param[2+2] = [1st motor] 2nd Byte
|
||||
...
|
||||
param[5+X] = [1st motor] X-th Byte
|
||||
param[6] = [2nd motor] ID
|
||||
param[6+1] = [2nd motor] 1st Byte
|
||||
param[6+2] = [2nd motor] 2nd Byte
|
||||
...
|
||||
param[6+X] = [2nd motor] X-th Byte
|
||||
|
||||
And 'length' = ((number_of_params * 1 + data_length) + 4), where:
|
||||
+1 is for instruction byte,
|
||||
+1 is for the address byte,
|
||||
+1 is for the length bytes,
|
||||
+1 is for the checksum at the end.
|
||||
"""
|
||||
data = []
|
||||
for idx, value in ids_values.items():
|
||||
split_value = FeetechMotorsBus.split_int_bytes(value, data_length)
|
||||
data += [idx, *split_value]
|
||||
params = [start_address, data_length, *data]
|
||||
length = len(ids_values) * (1 + data_length) + 4
|
||||
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write")
|
||||
|
||||
|
||||
class MockStatusPacket(MockFeetechPacket):
|
||||
"""
|
||||
|
@ -205,13 +245,38 @@ class MockPortHandler(scs.PortHandler):
|
|||
return True
|
||||
|
||||
|
||||
class WaitableStub(Stub):
|
||||
"""
|
||||
In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time
|
||||
to read, match, and call the stub. In these situations, the test can fail randomly.
|
||||
|
||||
Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._event = threading.Event()
|
||||
|
||||
def call(self):
|
||||
self._event.set()
|
||||
return super().call()
|
||||
|
||||
def wait_called(self, timeout: float = 1.0):
|
||||
return self._event.wait(timeout)
|
||||
|
||||
def wait_calls(self, min_calls: int = 1, timeout: float = 1.0):
|
||||
start = time.perf_counter()
|
||||
while time.perf_counter() - start < timeout:
|
||||
if self.calls >= min_calls:
|
||||
return self.calls
|
||||
time.sleep(0.005)
|
||||
raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.")
|
||||
|
||||
|
||||
class MockMotors(MockSerial):
|
||||
"""
|
||||
This class will simulate physical motors by responding with valid status packets upon receiving some
|
||||
instruction packets. It is meant to test MotorsBus classes.
|
||||
|
||||
'data_name' supported:
|
||||
- Present_Position
|
||||
"""
|
||||
|
||||
ctrl_table = SCS_SERIES_CONTROL_TABLE
|
||||
|
@ -220,6 +285,15 @@ class MockMotors(MockSerial):
|
|||
super().__init__()
|
||||
self.open()
|
||||
|
||||
@property
|
||||
def stubs(self) -> dict[str, WaitableStub]:
|
||||
return super().stubs
|
||||
|
||||
def stub(self, *, name=None, **kwargs):
|
||||
new_stub = WaitableStub(**kwargs)
|
||||
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
|
||||
return new_stub
|
||||
|
||||
def build_broadcast_ping_stub(
|
||||
self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0
|
||||
) -> str:
|
||||
|
@ -256,6 +330,10 @@ class MockMotors(MockSerial):
|
|||
def build_sync_read_stub(
|
||||
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
|
||||
) -> str:
|
||||
"""
|
||||
'data_name' supported:
|
||||
- Present_Position
|
||||
"""
|
||||
address, length = self.ctrl_table[data_name]
|
||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||
if data_name != "Present_Position":
|
||||
|
@ -274,6 +352,19 @@ class MockMotors(MockSerial):
|
|||
)
|
||||
return stub_name
|
||||
|
||||
def build_sync_write_stub(
|
||||
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
|
||||
) -> str:
|
||||
address, length = self.ctrl_table[data_name]
|
||||
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
|
||||
stub_name = f"Sync_Write_{data_name}_" + "_".join([str(idx) for idx in ids_values])
|
||||
self.stub(
|
||||
name=stub_name,
|
||||
receive_bytes=sync_read_request,
|
||||
send_fn=self._build_send_fn(b"", num_invalid_try),
|
||||
)
|
||||
return stub_name
|
||||
|
||||
@staticmethod
|
||||
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
||||
def send_fn(_call_count: int) -> bytes:
|
||||
|
|
|
@ -235,3 +235,54 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
|
|||
|
||||
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
|
||||
assert mock_motors.stubs[stub_name].calls == expected_calls
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"motors",
|
||||
[
|
||||
[1, 2, 3],
|
||||
["dummy_1", "dummy_2", "dummy_3"],
|
||||
[1, "dummy_2", 3],
|
||||
],
|
||||
ids=["by ids", "by names", "mixed"],
|
||||
)
|
||||
def test_write_all_motors(motors, dummy_motors):
|
||||
mock_motors = MockMotors()
|
||||
goal_positions = {
|
||||
1: 1337,
|
||||
2: 42,
|
||||
3: 4016,
|
||||
}
|
||||
stub_name = mock_motors.build_sync_write_stub("Goal_Position", goal_positions)
|
||||
motors_bus = FeetechMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
values = dict(zip(motors, goal_positions.values(), strict=True))
|
||||
motors_bus.write("Goal_Position", values)
|
||||
|
||||
assert mock_motors.stubs[stub_name].wait_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_name, value",
|
||||
[
|
||||
["Torque_Enable", 0],
|
||||
["Torque_Enable", 1],
|
||||
],
|
||||
)
|
||||
def test_write_all_motors_single_value(data_name, value, dummy_motors):
|
||||
mock_motors = MockMotors()
|
||||
values = {m.id: value for m in dummy_motors.values()}
|
||||
stub_name = mock_motors.build_sync_write_stub(data_name, values)
|
||||
motors_bus = FeetechMotorsBus(
|
||||
port=mock_motors.port,
|
||||
motors=dummy_motors,
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
motors_bus.write(data_name, value)
|
||||
|
||||
assert mock_motors.stubs[stub_name].wait_called()
|
||||
|
|
Loading…
Reference in New Issue