Add dxl write tests

This commit is contained in:
Simon Alibert 2025-03-22 14:50:05 +01:00
parent f2ed2bfb2f
commit 8ca03a7255
2 changed files with 159 additions and 5 deletions

View File

@ -1,12 +1,14 @@
import abc
import random
import threading
import time
from typing import Callable
import dynamixel_sdk as dxl
import serial
from mock_serial import MockSerial
from mock_serial.mock_serial import MockSerial, Stub
from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE
from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE, DynamixelMotorsBus
# https://emanual.robotis.com/docs/en/dxl/crc/
DXL_CRC_TABLE = [
@ -245,6 +247,53 @@ class MockInstructionPacket(MockDynamixelPacketv2):
length = len(dxl_ids) + 7
return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read")
@classmethod
def sync_write(
cls,
ids_values: dict[int],
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Sync_Write" broadcast instruction.
(from https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-write-0x83)
The parameters for Sync_Write (Protocol 2.0) are:
param[0] = start_address L
param[1] = start_address H
param[2] = data_length L
param[3] = data_length H
param[5] = [1st motor] ID
param[5+1] = [1st motor] 1st Byte
param[5+2] = [1st motor] 2nd Byte
...
param[5+X] = [1st motor] X-th Byte
param[6] = [2nd motor] ID
param[6+1] = [2nd motor] 1st Byte
param[6+2] = [2nd motor] 2nd Byte
...
param[6+X] = [2nd motor] X-th Byte
And 'length' = ((number_of_params * 1 + data_length) + 7), where:
+1 is for instruction byte,
+2 is for the address bytes,
+2 is for the length bytes,
+2 is for the CRC at the end.
"""
data = []
for idx, value in ids_values.items():
split_value = DynamixelMotorsBus.split_int_bytes(value, data_length)
data += [idx, *split_value]
params = [
dxl.DXL_LOBYTE(start_address),
dxl.DXL_HIBYTE(start_address),
dxl.DXL_LOBYTE(data_length),
dxl.DXL_HIBYTE(data_length),
*data,
]
length = len(ids_values) * (1 + data_length) + 7
return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write")
class MockStatusPacket(MockDynamixelPacketv2):
"""
@ -337,13 +386,38 @@ class MockPortHandler(dxl.PortHandler):
return True
class WaitableStub(Stub):
"""
In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time
to read, match, and call the stub. In these situations, the test can fail randomly.
Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._event = threading.Event()
def call(self):
self._event.set()
return super().call()
def wait_called(self, timeout: float = 1.0):
return self._event.wait(timeout)
def wait_calls(self, min_calls: int = 1, timeout: float = 1.0):
start = time.perf_counter()
while time.perf_counter() - start < timeout:
if self.calls >= min_calls:
return self.calls
time.sleep(0.005)
raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.")
class MockMotors(MockSerial):
"""
This class will simulate physical motors by responding with valid status packets upon receiving some
instruction packets. It is meant to test MotorsBus classes.
'data_name' supported:
- Present_Position
"""
ctrl_table = X_SERIES_CONTROL_TABLE
@ -352,6 +426,15 @@ class MockMotors(MockSerial):
super().__init__()
self.open()
@property
def stubs(self) -> dict[str, WaitableStub]:
return super().stubs
def stub(self, *, name=None, **kwargs):
new_stub = WaitableStub(**kwargs)
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
return new_stub
def build_broadcast_ping_stub(
self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0
) -> str:
@ -388,6 +471,10 @@ class MockMotors(MockSerial):
def build_sync_read_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
) -> str:
"""
'data_name' supported:
- Present_Position
"""
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
if data_name != "Present_Position":
@ -406,6 +493,22 @@ class MockMotors(MockSerial):
)
return stub_name
def build_sync_write_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
# if data_name != "Goal_Position":
# raise NotImplementedError
stub_name = f"Sync_Write_{data_name}_" + "_".join([str(idx) for idx in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=self._build_send_fn(b"", num_invalid_try),
)
return stub_name
@staticmethod
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes:

View File

@ -233,3 +233,54 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
assert mock_motors.stubs[stub_name].calls == expected_calls
@pytest.mark.parametrize(
"motors",
[
[1, 2, 3],
["dummy_1", "dummy_2", "dummy_3"],
[1, "dummy_2", 3],
],
ids=["by ids", "by names", "mixed"],
)
def test_write_all_motors(motors, dummy_motors):
mock_motors = MockMotors()
goal_positions = {
1: 1337,
2: 42,
3: 4016,
}
stub_name = mock_motors.build_sync_write_stub("Goal_Position", goal_positions)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect()
values = dict(zip(motors, goal_positions.values(), strict=True))
motors_bus.write("Goal_Position", values)
assert mock_motors.stubs[stub_name].wait_called()
@pytest.mark.parametrize(
"data_name, value",
[
["Torque_Enable", 0],
["Torque_Enable", 1],
],
)
def test_write_all_motors_single_value(data_name, value, dummy_motors):
mock_motors = MockMotors()
values = {m.id: value for m in dummy_motors.values()}
stub_name = mock_motors.build_sync_write_stub(data_name, values)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect()
motors_bus.write(data_name, value)
assert mock_motors.stubs[stub_name].wait_called()