Rename read/write -> sync_read/write, refactor, add write
This commit is contained in:
parent
a2f5c34625
commit
5a57e6f4a7
|
@ -55,8 +55,8 @@ class DynamixelMotorsBus(MotorsBus):
|
|||
|
||||
self.port_handler = dxl.PortHandler(self.port)
|
||||
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
|
||||
self.reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
|
||||
self.writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
|
||||
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
|
||||
self.sync_writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
|
||||
|
||||
def broadcast_ping(
|
||||
self, num_retry: int = 0, raise_on_error: bool = False
|
||||
|
@ -82,6 +82,9 @@ class DynamixelMotorsBus(MotorsBus):
|
|||
|
||||
return comm == dxl.COMM_SUCCESS
|
||||
|
||||
def _is_error(self, error: int) -> bool:
|
||||
return error != 0x00
|
||||
|
||||
@staticmethod
|
||||
def split_int_bytes(value: int, n_bytes: int) -> list[int]:
|
||||
# Validate input
|
||||
|
|
|
@ -51,8 +51,8 @@ class FeetechMotorsBus(MotorsBus):
|
|||
|
||||
self.port_handler = scs.PortHandler(self.port)
|
||||
self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION)
|
||||
self.reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
|
||||
self.writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
|
||||
self.sync_reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
|
||||
self.sync_writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
|
||||
|
||||
def broadcast_ping(self, num_retry: int | None = None):
|
||||
raise NotImplementedError # TODO
|
||||
|
@ -70,6 +70,9 @@ class FeetechMotorsBus(MotorsBus):
|
|||
|
||||
return comm == scs.COMM_SUCCESS
|
||||
|
||||
def _is_error(self, error: int) -> bool:
|
||||
return error != 0x00
|
||||
|
||||
@staticmethod
|
||||
def split_int_bytes(value: int, n_bytes: int) -> list[int]:
|
||||
# Validate input
|
||||
|
|
|
@ -256,8 +256,8 @@ class MotorsBus(abc.ABC):
|
|||
|
||||
self.port_handler: PortHandler
|
||||
self.packet_handler: PacketHandler
|
||||
self.reader: GroupSyncRead
|
||||
self.writer: GroupSyncWrite
|
||||
self.sync_reader: GroupSyncRead
|
||||
self.sync_writer: GroupSyncWrite
|
||||
|
||||
self.calibration = None
|
||||
|
||||
|
@ -347,7 +347,7 @@ class MotorsBus(abc.ABC):
|
|||
"""
|
||||
try:
|
||||
# TODO(aliberts): use ping instead
|
||||
return (self.ids == self.read("ID")).all()
|
||||
return (self.ids == self.sync_read("ID")).all()
|
||||
except ConnectionError as e:
|
||||
logger.error(e)
|
||||
return False
|
||||
|
@ -395,6 +395,10 @@ class MotorsBus(abc.ABC):
|
|||
def _is_comm_success(self, comm: int) -> bool:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _is_error(self, error: int) -> bool:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def split_int_bytes(value: int, n_bytes: int) -> list[int]:
|
||||
|
@ -442,12 +446,12 @@ class MotorsBus(abc.ABC):
|
|||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
@overload
|
||||
def read(self, data_name: str, motors: None = ..., num_retry: int = ...) -> dict[str, Value]: ...
|
||||
def sync_read(self, data_name: str, motors: None = ..., num_retry: int = ...) -> dict[str, Value]: ...
|
||||
@overload
|
||||
def read(
|
||||
def sync_read(
|
||||
self, data_name: str, motors: NameOrID | list[NameOrID], num_retry: int = ...
|
||||
) -> dict[NameOrID, Value]: ...
|
||||
def read(
|
||||
def sync_read(
|
||||
self, data_name: str, motors: NameOrID | list[NameOrID] | None = None, num_retry: int = 0
|
||||
) -> dict[NameOrID, Value]:
|
||||
if not self.is_connected:
|
||||
|
@ -466,17 +470,11 @@ class MotorsBus(abc.ABC):
|
|||
raise TypeError(motors)
|
||||
|
||||
motor_ids = list(id_key_map)
|
||||
if self._has_different_ctrl_tables:
|
||||
models = [self.id_to_model(idx) for idx in motor_ids]
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = self.id_to_model(next(iter(motor_ids)))
|
||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||
|
||||
comm, ids_values = self._read(motor_ids, addr, n_bytes, num_retry)
|
||||
comm, ids_values = self._sync_read(data_name, motor_ids, num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to read {data_name} on port {self.port} for ids {motor_ids}:"
|
||||
f"Failed to sync read '{data_name}' on {motor_ids=} after {num_retry + 1} tries."
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
|
@ -485,40 +483,50 @@ class MotorsBus(abc.ABC):
|
|||
|
||||
return {id_key_map[idx]: val for idx, val in ids_values.items()}
|
||||
|
||||
def _read(
|
||||
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 0
|
||||
def _sync_read(
|
||||
self, data_name: str, motor_ids: list[str], num_retry: int = 0
|
||||
) -> tuple[int, dict[int, int]]:
|
||||
self.reader.clearParam()
|
||||
self.reader.start_address = address
|
||||
self.reader.data_length = n_bytes
|
||||
if self._has_different_ctrl_tables:
|
||||
models = [self.id_to_model(idx) for idx in motor_ids]
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = self.id_to_model(next(iter(motor_ids)))
|
||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||
self._setup_sync_reader(motor_ids, addr, n_bytes)
|
||||
|
||||
# FIXME(aliberts, pkooij): We should probably not have to do this.
|
||||
# Let's try to see if we can do with better comm status handling instead.
|
||||
# self.port_handler.ser.reset_output_buffer()
|
||||
# self.port_handler.ser.reset_input_buffer()
|
||||
|
||||
for idx in motor_ids:
|
||||
self.reader.addParam(idx)
|
||||
|
||||
for n_try in range(1 + num_retry):
|
||||
comm = self.reader.txRxPacket()
|
||||
comm = self.sync_reader.txRxPacket()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(f"ids={list(motor_ids)} @{address} ({n_bytes} bytes) {n_try=} got {comm=}")
|
||||
logger.debug(f"Failed to sync read '{data_name}' ({addr=} {n_bytes=}) on {motor_ids=} ({n_try=})")
|
||||
logger.debug(self.packet_handler.getRxPacketError(comm))
|
||||
|
||||
values = {idx: self.reader.getData(idx, address, n_bytes) for idx in motor_ids}
|
||||
values = {idx: self.sync_reader.getData(idx, addr, n_bytes) for idx in motor_ids}
|
||||
return comm, values
|
||||
|
||||
# TODO(aliberts, pkooij): Implementing something like this could get much faster read times.
|
||||
# Note: this could be at the cost of increase latency between the moment the data is produced by the
|
||||
# motors and the moment it is used by a policy
|
||||
def _setup_sync_reader(self, motor_ids: list[str], addr: int, n_bytes: int) -> None:
|
||||
self.sync_reader.clearParam()
|
||||
self.sync_reader.start_address = addr
|
||||
self.sync_reader.data_length = n_bytes
|
||||
for idx in motor_ids:
|
||||
self.sync_reader.addParam(idx)
|
||||
|
||||
# TODO(aliberts, pkooij): Implementing something like this could get even much faster read times if need be.
|
||||
# Would have to handle the logic of checking if a packet has been sent previously though but doable.
|
||||
# This could be at the cost of increase latency between the moment the data is produced by the motors and
|
||||
# the moment it is used by a policy.
|
||||
# def _async_read(self, motor_ids: list[str], address: int, n_bytes: int):
|
||||
# self.reader.rxPacket()
|
||||
# self.reader.txPacket()
|
||||
# for idx in motor_ids:
|
||||
# value = self.reader.getData(idx, address, n_bytes)
|
||||
|
||||
def write(self, data_name: str, values: Value | dict[NameOrID, Value], num_retry: int = 0) -> None:
|
||||
def sync_write(self, data_name: str, values: Value | dict[NameOrID, Value], num_retry: int = 0) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
|
@ -531,40 +539,84 @@ class MotorsBus(abc.ABC):
|
|||
else:
|
||||
raise ValueError(f"'values' is expected to be a single value or a dict. Got {values}")
|
||||
|
||||
if data_name in self.calibration_required and self.calibration is not None:
|
||||
ids_values = self.uncalibrate_values(ids_values)
|
||||
|
||||
comm = self._sync_write(data_name, ids_values, num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||
f"\n{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def _sync_write(self, data_name: str, ids_values: dict[int, int], num_retry: int = 0) -> int:
|
||||
if self._has_different_ctrl_tables:
|
||||
models = [self.id_to_model(idx) for idx in ids_values]
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
if data_name in self.calibration_required and self.calibration is not None:
|
||||
ids_values = self.uncalibrate_values(ids_values)
|
||||
|
||||
model = self.id_to_model(next(iter(ids_values)))
|
||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||
|
||||
comm = self._write(ids_values, addr, n_bytes, num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to write {data_name} on port {self.port} for ids {list(ids_values)}:"
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 0) -> int:
|
||||
self.writer.clearParam()
|
||||
self.writer.start_address = address
|
||||
self.writer.data_length = n_bytes
|
||||
|
||||
for idx, value in ids_values.items():
|
||||
data = self.split_int_bytes(value, n_bytes)
|
||||
self.writer.addParam(idx, data)
|
||||
self._setup_sync_writer(ids_values, addr, n_bytes)
|
||||
|
||||
for n_try in range(1 + num_retry):
|
||||
comm = self.writer.txPacket()
|
||||
comm = self.sync_writer.txPacket()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(f"ids={list(ids_values)} @{address} ({n_bytes} bytes) {n_try=} got {comm=}")
|
||||
logger.debug(
|
||||
f"Failed to sync write '{data_name}' ({addr=} {n_bytes=}) with {ids_values=} ({n_try=})"
|
||||
)
|
||||
logger.debug(self.packet_handler.getRxPacketError(comm))
|
||||
|
||||
return comm
|
||||
|
||||
def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, n_bytes: int) -> None:
|
||||
self.sync_writer.clearParam()
|
||||
self.sync_writer.start_address = addr
|
||||
self.sync_writer.data_length = n_bytes
|
||||
for idx, value in ids_values.items():
|
||||
data = self.split_int_bytes(value, n_bytes)
|
||||
self.sync_writer.addParam(idx, data)
|
||||
|
||||
def write(self, data_name: str, motor: NameOrID, value: Value, num_retry: int = 0) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
idx = self.get_motor_id(motor)
|
||||
|
||||
if data_name in self.calibration_required and self.calibration is not None:
|
||||
id_value = self.uncalibrate_values({idx: value})
|
||||
value = id_value[idx]
|
||||
|
||||
comm, error = self._write(data_name, idx, value, num_retry)
|
||||
if not self._is_comm_success(comm):
|
||||
raise ConnectionError(
|
||||
f"Failed to write '{data_name}' on {idx=} with '{value}' after {num_retry + 1} tries."
|
||||
f"\n{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
elif self._is_error(error):
|
||||
raise RuntimeError(
|
||||
f"Failed to write '{data_name}' on {idx=} with '{value}' after {num_retry + 1} tries."
|
||||
f"\n{self.packet_handler.getRxPacketError(error)}"
|
||||
)
|
||||
|
||||
def _write(self, data_name: str, motor_id: int, value: int, num_retry: int = 0) -> tuple[int, int]:
|
||||
model = self.id_to_model(motor_id)
|
||||
addr, n_bytes = self.model_ctrl_table[model][data_name]
|
||||
data = self.split_int_bytes(value, n_bytes)
|
||||
|
||||
for n_try in range(1 + num_retry):
|
||||
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data)
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(
|
||||
f"Failed to write '{data_name}' ({addr=} {n_bytes=}) on {motor_id=} with '{value}' ({n_try=})"
|
||||
)
|
||||
logger.debug(self.packet_handler.getRxPacketError(comm))
|
||||
|
||||
return comm, error
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
|
|
|
@ -156,7 +156,7 @@ def test_read_all_motors(motors, mock_motors, dummy_motors):
|
|||
)
|
||||
motors_bus.connect()
|
||||
|
||||
positions_read = motors_bus.read("Present_Position", motors=motors)
|
||||
positions_read = motors_bus.sync_read("Present_Position", motors=motors)
|
||||
|
||||
motors = ["dummy_1", "dummy_2", "dummy_3"] if motors is None else motors
|
||||
assert mock_motors.stubs[stub_name].called
|
||||
|
@ -180,7 +180,7 @@ def test_read_single_motor_by_name(idx, pos, mock_motors, dummy_motors):
|
|||
)
|
||||
motors_bus.connect()
|
||||
|
||||
pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}")
|
||||
pos_dict = motors_bus.sync_read("Present_Position", f"dummy_{idx}")
|
||||
|
||||
assert mock_motors.stubs[stub_name].called
|
||||
assert pos_dict == {f"dummy_{idx}": pos}
|
||||
|
@ -203,7 +203,7 @@ def test_read_single_motor_by_id(idx, pos, mock_motors, dummy_motors):
|
|||
)
|
||||
motors_bus.connect()
|
||||
|
||||
pos_dict = motors_bus.read("Present_Position", idx)
|
||||
pos_dict = motors_bus.sync_read("Present_Position", idx)
|
||||
|
||||
assert mock_motors.stubs[stub_name].called
|
||||
assert pos_dict == {idx: pos}
|
||||
|
@ -230,11 +230,11 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_moto
|
|||
motors_bus.connect()
|
||||
|
||||
if num_retry >= num_invalid_try:
|
||||
pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry)
|
||||
pos_dict = motors_bus.sync_read("Present_Position", 1, num_retry=num_retry)
|
||||
assert pos_dict == {1: pos}
|
||||
else:
|
||||
with pytest.raises(ConnectionError):
|
||||
_ = motors_bus.read("Present_Position", 1, num_retry=num_retry)
|
||||
_ = motors_bus.sync_read("Present_Position", 1, num_retry=num_retry)
|
||||
|
||||
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
|
||||
assert mock_motors.stubs[stub_name].calls == expected_calls
|
||||
|
@ -263,7 +263,7 @@ def test_write_all_motors(motors, mock_motors, dummy_motors):
|
|||
motors_bus.connect()
|
||||
|
||||
values = dict(zip(motors, goal_positions.values(), strict=True))
|
||||
motors_bus.write("Goal_Position", values)
|
||||
motors_bus.sync_write("Goal_Position", values)
|
||||
|
||||
assert mock_motors.stubs[stub_name].wait_called()
|
||||
|
||||
|
@ -284,6 +284,6 @@ def test_write_all_motors_single_value(data_name, value, mock_motors, dummy_moto
|
|||
)
|
||||
motors_bus.connect()
|
||||
|
||||
motors_bus.write(data_name, value)
|
||||
motors_bus.sync_write(data_name, value)
|
||||
|
||||
assert mock_motors.stubs[stub_name].wait_called()
|
||||
|
|
|
@ -158,7 +158,7 @@ def test_read_all_motors(motors, mock_motors, dummy_motors):
|
|||
)
|
||||
motors_bus.connect()
|
||||
|
||||
positions_read = motors_bus.read("Present_Position", motors=motors)
|
||||
positions_read = motors_bus.sync_read("Present_Position", motors=motors)
|
||||
|
||||
motors = ["dummy_1", "dummy_2", "dummy_3"] if motors is None else motors
|
||||
assert mock_motors.stubs[stub_name].called
|
||||
|
@ -182,7 +182,7 @@ def test_read_single_motor_by_name(idx, pos, mock_motors, dummy_motors):
|
|||
)
|
||||
motors_bus.connect()
|
||||
|
||||
pos_dict = motors_bus.read("Present_Position", f"dummy_{idx}")
|
||||
pos_dict = motors_bus.sync_read("Present_Position", f"dummy_{idx}")
|
||||
|
||||
assert mock_motors.stubs[stub_name].called
|
||||
assert pos_dict == {f"dummy_{idx}": pos}
|
||||
|
@ -205,7 +205,7 @@ def test_read_single_motor_by_id(idx, pos, mock_motors, dummy_motors):
|
|||
)
|
||||
motors_bus.connect()
|
||||
|
||||
pos_dict = motors_bus.read("Present_Position", idx)
|
||||
pos_dict = motors_bus.sync_read("Present_Position", idx)
|
||||
|
||||
assert mock_motors.stubs[stub_name].called
|
||||
assert pos_dict == {idx: pos}
|
||||
|
@ -232,11 +232,11 @@ def test_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_moto
|
|||
motors_bus.connect()
|
||||
|
||||
if num_retry >= num_invalid_try:
|
||||
pos_dict = motors_bus.read("Present_Position", 1, num_retry=num_retry)
|
||||
pos_dict = motors_bus.sync_read("Present_Position", 1, num_retry=num_retry)
|
||||
assert pos_dict == {1: pos}
|
||||
else:
|
||||
with pytest.raises(ConnectionError):
|
||||
_ = motors_bus.read("Present_Position", 1, num_retry=num_retry)
|
||||
_ = motors_bus.sync_read("Present_Position", 1, num_retry=num_retry)
|
||||
|
||||
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
|
||||
assert mock_motors.stubs[stub_name].calls == expected_calls
|
||||
|
@ -265,7 +265,7 @@ def test_write_all_motors(motors, mock_motors, dummy_motors):
|
|||
motors_bus.connect()
|
||||
|
||||
values = dict(zip(motors, goal_positions.values(), strict=True))
|
||||
motors_bus.write("Goal_Position", values)
|
||||
motors_bus.sync_write("Goal_Position", values)
|
||||
|
||||
assert mock_motors.stubs[stub_name].wait_called()
|
||||
|
||||
|
@ -286,6 +286,6 @@ def test_write_all_motors_single_value(data_name, value, mock_motors, dummy_moto
|
|||
)
|
||||
motors_bus.connect()
|
||||
|
||||
motors_bus.write(data_name, value)
|
||||
motors_bus.sync_write(data_name, value)
|
||||
|
||||
assert mock_motors.stubs[stub_name].wait_called()
|
||||
|
|
Loading…
Reference in New Issue