Add dxl write tests

This commit is contained in:
Simon Alibert 2025-03-22 14:50:05 +01:00
parent f2ed2bfb2f
commit 8ca03a7255
2 changed files with 159 additions and 5 deletions

View File

@ -1,12 +1,14 @@
import abc import abc
import random import random
import threading
import time
from typing import Callable from typing import Callable
import dynamixel_sdk as dxl import dynamixel_sdk as dxl
import serial import serial
from mock_serial import MockSerial from mock_serial.mock_serial import MockSerial, Stub
from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE, DynamixelMotorsBus
# https://emanual.robotis.com/docs/en/dxl/crc/ # https://emanual.robotis.com/docs/en/dxl/crc/
DXL_CRC_TABLE = [ DXL_CRC_TABLE = [
@ -245,6 +247,53 @@ class MockInstructionPacket(MockDynamixelPacketv2):
length = len(dxl_ids) + 7 length = len(dxl_ids) + 7
return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read") return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read")
@classmethod
def sync_write(
cls,
ids_values: dict[int],
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Sync_Write" broadcast instruction.
(from https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-write-0x83)
The parameters for Sync_Write (Protocol 2.0) are:
param[0] = start_address L
param[1] = start_address H
param[2] = data_length L
param[3] = data_length H
param[5] = [1st motor] ID
param[5+1] = [1st motor] 1st Byte
param[5+2] = [1st motor] 2nd Byte
...
param[5+X] = [1st motor] X-th Byte
param[6] = [2nd motor] ID
param[6+1] = [2nd motor] 1st Byte
param[6+2] = [2nd motor] 2nd Byte
...
param[6+X] = [2nd motor] X-th Byte
And 'length' = ((number_of_params * 1 + data_length) + 7), where:
+1 is for instruction byte,
+2 is for the address bytes,
+2 is for the length bytes,
+2 is for the CRC at the end.
"""
data = []
for idx, value in ids_values.items():
split_value = DynamixelMotorsBus.split_int_bytes(value, data_length)
data += [idx, *split_value]
params = [
dxl.DXL_LOBYTE(start_address),
dxl.DXL_HIBYTE(start_address),
dxl.DXL_LOBYTE(data_length),
dxl.DXL_HIBYTE(data_length),
*data,
]
length = len(ids_values) * (1 + data_length) + 7
return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write")
class MockStatusPacket(MockDynamixelPacketv2): class MockStatusPacket(MockDynamixelPacketv2):
""" """
@ -337,13 +386,38 @@ class MockPortHandler(dxl.PortHandler):
return True return True
class WaitableStub(Stub):
"""
In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time
to read, match, and call the stub. In these situations, the test can fail randomly.
Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._event = threading.Event()
def call(self):
self._event.set()
return super().call()
def wait_called(self, timeout: float = 1.0):
return self._event.wait(timeout)
def wait_calls(self, min_calls: int = 1, timeout: float = 1.0):
start = time.perf_counter()
while time.perf_counter() - start < timeout:
if self.calls >= min_calls:
return self.calls
time.sleep(0.005)
raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.")
class MockMotors(MockSerial): class MockMotors(MockSerial):
""" """
This class will simulate physical motors by responding with valid status packets upon receiving some This class will simulate physical motors by responding with valid status packets upon receiving some
instruction packets. It is meant to test MotorsBus classes. instruction packets. It is meant to test MotorsBus classes.
'data_name' supported:
- Present_Position
""" """
ctrl_table = X_SERIES_CONTROL_TABLE ctrl_table = X_SERIES_CONTROL_TABLE
@ -352,6 +426,15 @@ class MockMotors(MockSerial):
super().__init__() super().__init__()
self.open() self.open()
@property
def stubs(self) -> dict[str, WaitableStub]:
return super().stubs
def stub(self, *, name=None, **kwargs):
new_stub = WaitableStub(**kwargs)
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
return new_stub
def build_broadcast_ping_stub( def build_broadcast_ping_stub(
self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0 self, ids_models_firmwares: dict[int, list[int]] | None = None, num_invalid_try: int = 0
) -> str: ) -> str:
@ -388,6 +471,10 @@ class MockMotors(MockSerial):
def build_sync_read_stub( def build_sync_read_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0 self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
) -> str: ) -> str:
"""
'data_name' supported:
- Present_Position
"""
address, length = self.ctrl_table[data_name] address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
if data_name != "Present_Position": if data_name != "Present_Position":
@ -406,6 +493,22 @@ class MockMotors(MockSerial):
) )
return stub_name return stub_name
def build_sync_write_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
# if data_name != "Goal_Position":
# raise NotImplementedError
stub_name = f"Sync_Write_{data_name}_" + "_".join([str(idx) for idx in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=self._build_send_fn(b"", num_invalid_try),
)
return stub_name
@staticmethod @staticmethod
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes: def send_fn(_call_count: int) -> bytes:

View File

@ -233,3 +233,54 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, dummy_motors):
expected_calls = min(1 + num_retry, 1 + num_invalid_try) expected_calls = min(1 + num_retry, 1 + num_invalid_try)
assert mock_motors.stubs[stub_name].calls == expected_calls assert mock_motors.stubs[stub_name].calls == expected_calls
@pytest.mark.parametrize(
"motors",
[
[1, 2, 3],
["dummy_1", "dummy_2", "dummy_3"],
[1, "dummy_2", 3],
],
ids=["by ids", "by names", "mixed"],
)
def test_write_all_motors(motors, dummy_motors):
mock_motors = MockMotors()
goal_positions = {
1: 1337,
2: 42,
3: 4016,
}
stub_name = mock_motors.build_sync_write_stub("Goal_Position", goal_positions)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect()
values = dict(zip(motors, goal_positions.values(), strict=True))
motors_bus.write("Goal_Position", values)
assert mock_motors.stubs[stub_name].wait_called()
@pytest.mark.parametrize(
"data_name, value",
[
["Torque_Enable", 0],
["Torque_Enable", 1],
],
)
def test_write_all_motors_single_value(data_name, value, dummy_motors):
mock_motors = MockMotors()
values = {m.id: value for m in dummy_motors.values()}
stub_name = mock_motors.build_sync_write_stub(data_name, values)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect()
motors_bus.write(data_name, value)
assert mock_motors.stubs[stub_name].wait_called()