Compare commits

..

42 Commits

Author SHA1 Message Date
Steven Palma a0657ee274
refactor(robots): multiple changes from feedback 2025-04-17 14:40:55 +02:00
Steven Palma d7b9866a7c
refactor(robots): update lekiwi for the latest motor bus api
chore(teleop): Add missing abstract methods to keyboard implementation

refactor(robots): update lekiwi client and host code for the new api

chore(config): update host lekiwi ip in client config

chore(examples): move application scripts to the examples directory

fix(motors): missing type check condition in set_half_turn_homings

fix(robots): fix assumption in calibrate() for robots with more than just an arm

fix(robot): change Mode to Operating_Mode in configure write for lekiwi

fix(robots): make sure message is display in calibrate() method lekiwi

fix(robots): no need for .tolist() in lekiwi host app

fix(teleop): fix is_connected in teleoperator keyboard

fix(teleop): always display calibration message in so100

fix(robots): fix send_action in lekiwi_client

debug(examples): configuration for lekiwi client app

fix(robots): fix send_action in lekiwi client part 2

refactor(robots): use dicts in lekiwi for get_obs and send_action

dbg(robots): check sent action wheels lekiwi

debug(robots): fix overflow base commands

debug(robots): fix how we deal with negative values lekiwi

debug(robots): lekiwi sign degrees fix

fix(robots): right motors id in lekiwi host

chore(doc): update todos

chore(doc): added todos
2025-04-17 14:37:57 +02:00
Simon Alibert 6cd06196c3
Group config files 2025-04-17 14:37:53 +02:00
Simon Alibert 4f5d840cac
Cleanup imports 2025-04-17 14:37:50 +02:00
Simon Alibert 7dedbeb457
Rename Lekiwi files & classes 2025-04-17 14:37:45 +02:00
Simon Alibert 6b4931b4f0
Update Lekiwi with new MotorsBus 2025-04-17 14:37:22 +02:00
Steven Palma a38e989cab
refactor(kiwi): update to latest motor API 2025-04-17 14:37:13 +02:00
Steven Palma 833ab383dd
chore(doc): update todos + license 2025-04-17 14:37:09 +02:00
Steven Palma 48b7e2a137
feat(lekiwi): Make dataset recording work 2025-04-17 14:37:05 +02:00
Steven Palma 2b100122f5
feat(lekiwi): de-couple classes + make it single-threaded 2025-04-17 14:36:58 +02:00
Steven Palma 66325b5a42
fix(lekiwi): fix calibration issue 2025-04-17 14:34:59 +02:00
Steven Palma dc3360c06b
fix(lekiwi): HW fixes v0.4 2025-04-17 14:34:55 +02:00
Steven Palma 66017f16a0
fix(lekiwi): HW fixes v0.3 2025-04-17 14:34:52 +02:00
Steven Palma 87b0a5995c
fix(lekiwi): HW fixes v0.2 2025-04-17 14:34:48 +02:00
pre-commit-ci[bot] cf35a5e986
fix(lekiwi): HW fixes v0.1 2025-04-17 14:34:44 +02:00
Steven Palma caa69be553
refactor(robots): lekiwi v0.5 2025-04-17 14:34:41 +02:00
Steven Palma a247e4b2be
refactor(robots): lekiwi v0.4 2025-04-17 14:34:35 +02:00
Steven Palma 5c925c779b
refactor(robots): lewiki v0.3 2025-04-17 14:34:29 +02:00
Steven Palma 73956e31b2
refactor(robots): lekiwi v0.2 2025-04-17 14:34:13 +02:00
Steven Palma d43f1a8136
refactor(robots): lewiki v0.1 2025-04-17 14:34:05 +02:00
Simon Alibert bf1c737858 Fix calibration msg display 2025-04-17 13:18:32 +02:00
Simon Alibert d07c7347f8 Add setup_motor 2025-04-17 13:14:06 +02:00
Simon Alibert 57e5e4cc07 Move read/write_calibration implementations 2025-04-16 11:23:33 +02:00
Simon Alibert 2743c29a96 Update feetech tables 2025-04-16 11:01:12 +02:00
Simon Alibert 2bb73ac431 Add torque_disabled context 2025-04-15 11:43:22 +02:00
Simon Alibert 9afc4b771c Motors config & disconnect fixes 2025-04-15 11:20:42 +02:00
Simon Alibert f71e224023 Fix tests 2025-04-15 11:18:44 +02:00
Simon Alibert 889de7c415 Add handshake, fix feetech _read_firmware_version 2025-04-14 17:14:06 +02:00
Simon Alibert 3539251b18 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-04-14 15:30:35 +02:00
Simon Alibert 1f210bc8a3 Refactor tests 2025-04-14 15:26:29 +02:00
Simon Alibert d70bc4bde9 Add more segmented tests (dynamixel) 2025-04-14 15:16:38 +02:00
Simon Alibert bdbca09cb2 Add more segmented tests (base motor bus & feetech), add feetech protocol 1 support 2025-04-14 11:56:53 +02:00
Simon Alibert e0b292ab51 Remove test_motors_bus fixtures 2025-04-11 12:24:30 +02:00
Simon Alibert f960f4d8d4 Fix unormalize 2025-04-11 11:58:31 +02:00
Simon Alibert 9e57ec7837 Add support for feetech protocol 1 to _split_into_byte_chunks 2025-04-11 11:58:09 +02:00
Simon Alibert 0a7f51f0da Cleanup 2025-04-11 11:03:09 +02:00
Simon Alibert 4ca92a28e9 Make feetech broadcast ping faster in protocol 1 2025-04-11 11:02:54 +02:00
Simon Alibert 0464dc91b3 Add feetech sm8512bl 2025-04-11 11:02:01 +02:00
Simon Alibert d32daebf75 Refactor & add _serialize_data 2025-04-11 11:01:12 +02:00
Steven Palma 5322417c03
fix(examples): removes extra backtick (#948) 2025-04-09 17:44:32 +02:00
Steven Palma 4041f57943
feat(visualization): replace cv2 GUI with Rerun (and solves ffmpeg versioning issues) (#903) 2025-04-09 17:33:01 +02:00
Simon Alibert 2c86fea78a
Switch typos pre-commit to mirror (#953) 2025-04-08 12:44:09 +02:00
44 changed files with 1862 additions and 1188 deletions

2
.gitignore vendored
View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
.dev
# Logging
logs
tmp

View File

@ -36,8 +36,8 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/crate-ci/typos
rev: v1
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.31.1
hooks:
- id: typos
args: [--force-exclude]

View File

@ -98,14 +98,14 @@ conda create -y -n lerobot python=3.10
conda activate lerobot
```
When using `miniconda`, if you don't have `ffmpeg` in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg
conda install ffmpeg -c conda-forge
```
Install 🤗 LeRobot:
```bash
pip install --no-binary=av -e .
pip install -e .
```
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
@ -118,7 +118,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
For instance, to install 🤗 LeRobot with aloha and pusht, use:
```bash
pip install --no-binary=av -e ".[aloha, pusht]"
pip install -e ".[aloha, pusht]"
```
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with

View File

@ -17,12 +17,21 @@
import argparse
import datetime as dt
import os
import time
from pathlib import Path
import cv2
import rerun as rr
# see https://rerun.io/docs/howto/visualization/limit-ram
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int):
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
rr.init("lerobot_capture_camera_feed")
rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)
now = dt.datetime.now()
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
if not capture_dir.exists():
@ -39,24 +48,21 @@ def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
frame_index = 0
while True:
start_time = time.time()
while time.time() - start_time < duration:
ret, frame = cap.read()
if not ret:
print("Error: Could not read frame.")
break
cv2.imshow("Video Stream", frame)
rr.log("video/stream", rr.Image(frame.numpy()), static=True)
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
frame_index += 1
# Break the loop on 'q' key press
if cv2.waitKey(1) & 0xFF == ord("q"):
break
# Release the capture and destroy all windows
# Release the capture
cap.release()
cv2.destroyAllWindows()
# TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.
if __name__ == "__main__":
@ -86,5 +92,11 @@ if __name__ == "__main__":
default=720,
help="Height of the captured images.",
)
parser.add_argument(
"--duration",
type=int,
default=20,
help="Duration in seconds for which the video stream should be captured.",
)
args = parser.parse_args()
display_and_save_video_stream(**vars(args))

View File

@ -18,7 +18,7 @@ training outputs directory. In the latter case, you might want to run examples/3
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
```bash
pip install --no-binary=av -e ".[pusht]"`
pip install -e ".[pusht]"
```
"""

View File

@ -33,7 +33,7 @@ First, install the additional dependencies required for robots built with dynami
Using `pip`:
```bash
pip install --no-binary=av -e ".[dynamixel]"
pip install -e ".[dynamixel]"
```
Using `poetry`:
@ -55,6 +55,9 @@ Finally, connect both arms to your computer via USB. Note that the USB doesn't p
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=koch \
@ -828,10 +831,10 @@ It contains:
Troubleshooting:
- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
- or, install [Homebrew](https://brew.sh) and run `brew install ffmpeg` (it should be compiled with `libsvtav1`),
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform (check the version installed with `ffmpeg -encoders | grep libsvtav1`). If it isn't `ffmpeg 7.X` or lacks `libsvtav1` support, you can explicitly install `ffmpeg 7.X` using: `conda install ffmpeg=7.1.1 -c conda-forge`
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running:

View File

@ -14,12 +14,12 @@
import logging
from lerobot.common.robots.config import RobotMode
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig, RobotMode
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.common.teleoperators.so100 import SO100Leader, SO100LeaderConfig
# TODO(Steven): Check validity of these features
DUMMY_FEATURES = {
"observation.state": {
"dtype": "float64",
@ -119,6 +119,7 @@ def main():
# robot.set_mode(RobotMode.AUTO)
# policy_action = policy.get_action() # This might be in body frame, key space or smt else
# robot.send_action(policy_action)
action_sent = robot.send_action(action)
observation = robot.get_observation()

View File

@ -35,7 +35,7 @@ from .tables import (
)
PROTOCOL_VERSION = 2.0
BAUDRATE = 1_000_000
DEFAULT_BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
@ -84,6 +84,23 @@ class TorqueMode(Enum):
DISABLED = 0
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
import dynamixel_sdk as dxl
if length == 1:
data = [value]
elif length == 2:
data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
elif length == 4:
data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
]
return data
class DynamixelMotorsBus(MotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
@ -92,6 +109,7 @@ class DynamixelMotorsBus(MotorsBus):
"""
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
@ -119,19 +137,70 @@ class DynamixelMotorsBus(MotorsBus):
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
pass
def _handshake(self) -> None:
self._assert_motors_exist()
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
model = self.motors[motor].model
search_baudrates = (
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
)
for baudrate in search_baudrates:
self.set_baudrate(baudrate)
id_model = self.broadcast_ping()
if id_model:
found_id, found_model = next(iter(id_model.items()))
expected_model_nb = self.model_number_table[model]
if found_model != expected_model_nb:
raise RuntimeError(
f"Found one motor on {baudrate=} with id={found_id} but it has a "
f"model number '{found_model}' different than the one expected: '{expected_model_nb}' "
f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')."
)
return baudrate, found_id
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def configure_motors(self) -> None:
# By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
for id_ in self.ids:
self.write("Return_Delay_Time", id_, 0)
for motor in self.motors:
self.write("Return_Delay_Time", motor, 0)
def disable_torque(self, motors: str | list[str] | None = None) -> None:
for name in self._get_names_list(motors):
self.write("Torque_Enable", name, TorqueMode.DISABLED.value)
def read_calibration(self) -> dict[str, MotorCalibration]:
offsets = self.sync_read("Homing_Offset", normalize=False)
mins = self.sync_read("Min_Position_Limit", normalize=False)
maxes = self.sync_read("Max_Position_Limit", normalize=False)
drive_modes = self.sync_read("Drive_Mode", normalize=False)
def enable_torque(self, motors: str | list[str] | None = None) -> None:
for name in self._get_names_list(motors):
self.write("Torque_Enable", name, TorqueMode.ENABLED.value)
calibration = {}
for name, motor in self.motors.items():
calibration[name] = MotorCalibration(
id=motor.id,
drive_mode=drive_modes[name],
homing_offset=offsets[name],
range_min=mins[name],
range_max=maxes[name],
)
return calibration
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
for motor, calibration in calibration_dict.items():
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
self.calibration = calibration_dict
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry)
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
@ -166,22 +235,8 @@ class DynamixelMotorsBus(MotorsBus):
return half_turn_homings
@staticmethod
def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]:
import dynamixel_sdk as dxl
if n_bytes == 1:
data = [value]
elif n_bytes == 2:
data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
elif n_bytes == 4:
data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
]
return data
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
return _split_into_byte_chunks(value, length)
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
for n_try in range(1 + num_retry):

View File

@ -1,3 +1,24 @@
# TODO(Steven): Consider doing the following:
# from enum import Enum
# class MyControlTableKey(Enum):
# ID = "ID"
# GOAL_SPEED = "Goal_Speed"
# ...
#
# MY_CONTROL_TABLE ={
# MyControlTableKey.ID.value: (5,1)
# MyControlTableKey.GOAL_SPEED.value: (46, 2)
# ...
# }
# This allows me do to:
# bus.write(MyControlTableKey.GOAL_SPEED, ...)
# Instead of:
# bus.write("Goal_Speed", ...)
# This is important for two reasons:
# 1. The linter will tell me if I'm trying to use an invalid key, instead of me realizing when I get the RunTimeError
# 2. We can change the value of the MyControlTableKey enums without impacting the client code
# {data_name: (address, size_byte)}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table
X_SERIES_CONTROL_TABLE = {
@ -57,13 +78,13 @@ X_SERIES_CONTROL_TABLE = {
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#baud-rate8
X_SERIES_BAUDRATE_TABLE = {
0: 9_600,
1: 57_600,
2: 115_200,
3: 1_000_000,
4: 2_000_000,
5: 3_000_000,
6: 4_000_000,
9_600: 0,
57_600: 1,
115_200: 2,
1_000_000: 3,
2_000_000: 4,
3_000_000: 5,
4_000_000: 6,
}
# {data_name: size_byte}

View File

@ -21,18 +21,20 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
from .tables import (
FIRMWARE_VERSION,
FIRMWARE_MAJOR_VERSION,
FIRMWARE_MINOR_VERSION,
MODEL_BAUDRATE_TABLE,
MODEL_CONTROL_TABLE,
MODEL_ENCODING_TABLE,
MODEL_NUMBER,
MODEL_NUMBER_TABLE,
MODEL_PROTOCOL,
MODEL_RESOLUTION,
SCAN_BAUDRATES,
)
DEFAULT_PROTOCOL_VERSION = 0
BAUDRATE = 1_000_000
DEFAULT_BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
@ -64,6 +66,23 @@ class TorqueMode(Enum):
DISABLED = 0
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
import scservo_sdk as scs
if length == 1:
data = [value]
elif length == 2:
data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
elif length == 4:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
]
return data
def patch_setPacketTimeout(self, packet_length): # noqa: N802
"""
HACK: This patches the PortHandler behavior to set the correct packet timeouts.
@ -84,6 +103,7 @@ class FeetechMotorsBus(MotorsBus):
"""
available_baudrates = deepcopy(SCAN_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
@ -100,9 +120,10 @@ class FeetechMotorsBus(MotorsBus):
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
):
super().__init__(port, motors, calibration)
self.protocol_version = protocol_version
self._assert_same_protocol()
import scservo_sdk as scs
self.protocol_version = protocol_version
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
@ -114,17 +135,132 @@ class FeetechMotorsBus(MotorsBus):
self._comm_success = scs.COMM_SUCCESS
self._no_error = 0x00
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}")
def _assert_same_protocol(self) -> None:
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
raise RuntimeError("Some motors use an incompatible protocol.")
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
if instruction_name == "sync_read" and self.protocol_version == 1:
raise NotImplementedError(
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' instead."
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead."
)
if instruction_name == "broadcast_ping" and self.protocol_version == 1:
raise NotImplementedError(
"'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead."
)
def _assert_same_firmware(self) -> None:
firmware_versions = self._read_firmware_version(self.ids)
if len(set(firmware_versions.values())) != 1:
raise RuntimeError(
"Some Motors use different firmware versions. Update their firmware first using Feetech's software. "
"Visit https://www.feetechrc.com/software."
)
def _handshake(self) -> None:
self._assert_motors_exist()
self._assert_same_firmware()
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
if self.protocol_version == 0:
return self._find_single_motor_p0(motor, initial_baudrate)
else:
return self._find_single_motor_p1(motor, initial_baudrate)
def _find_single_motor_p0(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
model = self.motors[motor].model
search_baudrates = (
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
)
expected_model_nb = self.model_number_table[model]
for baudrate in search_baudrates:
self.set_baudrate(baudrate)
id_model = self.broadcast_ping()
if id_model:
found_id, found_model = next(iter(id_model.items()))
if found_model != expected_model_nb:
raise RuntimeError(
f"Found one motor on {baudrate=} with id={found_id} but it has a "
f"model number '{found_model}' different than the one expected: '{expected_model_nb}' "
f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')."
)
return baudrate, found_id
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
import scservo_sdk as scs
model = self.motors[motor].model
search_baudrates = (
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
)
expected_model_nb = self.model_number_table[model]
for baudrate in search_baudrates:
self.set_baudrate(baudrate)
for id_ in range(scs.MAX_ID + 1):
found_model = self.ping(id_)
if found_model is not None and found_model != expected_model_nb:
raise RuntimeError(
f"Found one motor on {baudrate=} with id={id_} but it has a "
f"model number '{found_model}' different than the one expected: '{expected_model_nb}' "
f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')."
)
return baudrate, id_
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def configure_motors(self) -> None:
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on the
# 'Return_Delay' address). We ensure this is reduced to the minimum of 2µs (value of 0).
for id_ in self.ids:
self.write("Return_Delay_Time", id_, 0)
for motor in self.motors:
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
self.write("Return_Delay_Time", motor, 0)
# Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors.
# Note: this address is not in the official STS3215 Memory Table
self.write("Maximum_Acceleration", motor, 254)
self.write("Acceleration", motor, 254)
def read_calibration(self) -> dict[str, MotorCalibration]:
if self.protocol_version == 0:
offsets = self.sync_read("Homing_Offset", normalize=False)
mins = self.sync_read("Min_Position_Limit", normalize=False)
maxes = self.sync_read("Max_Position_Limit", normalize=False)
drive_modes = dict.fromkeys(self.motors, 0)
else:
offsets, mins, maxes, drive_modes = {}, {}, {}, {}
for motor in self.motors:
offsets[motor] = 0
mins[motor] = self.read("Min_Position_Limit", motor, normalize=False)
maxes[motor] = self.read("Max_Position_Limit", motor, normalize=False)
drive_modes[motor] = 0
# TODO(aliberts): add set/get_drive_mode?
calibration = {}
for name, motor in self.motors.items():
calibration[name] = MotorCalibration(
id=motor.id,
drive_mode=drive_modes[name],
homing_offset=offsets[name],
range_min=mins[name],
range_max=maxes[name],
)
return calibration
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
for motor, calibration in calibration_dict.items():
if self.protocol_version == 0:
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
self.calibration = calibration_dict
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
"""
@ -139,15 +275,15 @@ class FeetechMotorsBus(MotorsBus):
return half_turn_homings
def disable_torque(self, motors: str | list[str] | None = None) -> None:
for name in self._get_names_list(motors):
self.write("Torque_Enable", name, TorqueMode.DISABLED.value)
self.write("Lock", name, 0)
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", name, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None) -> None:
for name in self._get_names_list(motors):
self.write("Torque_Enable", name, TorqueMode.ENABLED.value)
self.write("Lock", name, 1)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", name, 1, num_retry=num_retry)
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
@ -169,40 +305,10 @@ class FeetechMotorsBus(MotorsBus):
return ids_values
@staticmethod
def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]:
import scservo_sdk as scs
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
return _split_into_byte_chunks(value, length)
if n_bytes == 1:
data = [value]
elif n_bytes == 2:
data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
elif n_bytes == 4:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
]
return data
def _broadcast_ping_p1(self, known_motors_only: bool = True, num_retry: int = 0) -> dict[int, int]:
if known_motors_only:
ids = self.ids
else:
import scservo_sdk as scs
ids = range(scs.MAX_ID + 1)
ids_models = {}
for id_ in ids:
model_number = self.ping(id_, num_retry)
if model_number is not None:
ids_models[id_] = model_number
return ids_models
def _broadcast_ping_p0(self) -> tuple[dict[int, int], int]:
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list = {}
@ -277,83 +383,52 @@ class FeetechMotorsBus(MotorsBus):
rx_length = rx_length - idx
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
if self.protocol_version == 0:
for n_try in range(1 + num_retry):
ids_status, comm = self._broadcast_ping_p0()
if self._is_comm_success(comm):
break
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
logger.debug(self.packet_handler.getTxRxResult(comm))
self._assert_protocol_is_compatible("broadcast_ping")
for n_try in range(1 + num_retry):
ids_status, comm = self._broadcast_ping()
if self._is_comm_success(comm):
break
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
logger.debug(self.packet_handler.getTxRxResult(comm))
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
display_dict = {
id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()
}
logger.error(
f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}"
)
return self._get_model_number(list(ids_status), raise_on_error)
else:
return self._broadcast_ping_p1(num_retry=num_retry)
def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
# comm, major = self._sync_read(*FIRMWARE_MAJOR_VERSION, motor_ids)
# if not self._is_comm_success(comm):
# if raise_on_error:
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
# return
# comm, minor = self._sync_read(*FIRMWARE_MINOR_VERSION, motor_ids)
# if not self._is_comm_success(comm):
# if raise_on_error:
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
# return
# return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids}
comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids)
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return firmware_versions
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
# comm, major = self._sync_read(*MODEL_MAJOR_VERSION, motor_ids)
# if not self._is_comm_success(comm):
# if raise_on_error:
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
# return
return self._read_model_number(list(ids_status), raise_on_error)
# comm, minor = self._sync_read(*MODEL_MINOR_VERSION, motor_ids)
# if not self._is_comm_success(comm):
# if raise_on_error:
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
# return
# return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids}
if self.protocol_version == 1:
model_numbers = {}
for id_ in motor_ids:
model_nb, comm, error = self._read(*MODEL_NUMBER, id_)
if self._is_comm_success(comm) and not self._is_error(error):
model_numbers[id_] = model_nb
elif raise_on_error:
raise Exception # FIX
else:
comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids)
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
def _read_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]:
firmware_versions = {}
for id_ in motor_ids:
firm_ver_major, comm, error = self._read(
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
)
if not self._is_comm_success(comm) or self._is_error(error):
return
firm_ver_minor, comm, error = self._read(
*FIRMWARE_MINOR_VERSION, id_, raise_on_error=raise_on_error
)
if not self._is_comm_success(comm) or self._is_error(error):
return
firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}"
return firmware_versions
def _read_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
model_numbers = {}
for id_ in motor_ids:
model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error)
if not self._is_comm_success(comm) or self._is_error(error):
return
model_numbers[id_] = model_nb
return model_numbers

View File

@ -1,22 +1,34 @@
FIRMWARE_MAJOR_VERSION = (0, 1)
FIRMWARE_MINOR_VERSION = (1, 1)
MODEL_MAJOR_VERSION = (3, 1)
MODEL_MINOR_VERSION = (4, 1)
FIRMWARE_VERSION = (0, 2)
MODEL_NUMBER = (3, 2)
# See this link for STS3215 Memory Table:
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
# TODO(Steven): Consider doing the following:
# from enum import Enum
# class MyControlTableKey(Enum):
# ID = "ID"
# GOAL_SPEED = "Goal_Speed"
# ...
#
# MY_CONTROL_TABLE ={
# MyControlTableKey.ID.value: (5,1)
# MyControlTableKey.GOAL_SPEED.value: (46, 2)
# ...
# }
# This allows me do to:
# bus.write(MyControlTableKey.GOAL_SPEED, ...)
# Instead of:
# bus.write("Goal_Speed", ...)
# This is important for two reasons:
# 1. The linter will tell me if I'm trying to use an invalid key, instead of me realizing when I get the RunTimeError
# 2. We can change the value of the MyControlTableKey enums without impacting the client code
# data_name: (address, size_byte)
# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SMS-STS-emanual-229f4476422d4059abfb1cb0
STS_SMS_SERIES_CONTROL_TABLE = {
# EPROM
"Firmware_Version": FIRMWARE_VERSION, # read-only
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
"Model_Number": MODEL_NUMBER, # read-only
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
# "Model_Minor_Version": MODEL_MINOR_VERSION,
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay_Time": (7, 1),
@ -43,7 +55,7 @@ STS_SMS_SERIES_CONTROL_TABLE = {
"Protective_Torque": (34, 1),
"Protection_Time": (35, 1),
"Overload_Torque": (36, 1),
"Speed_closed_loop_P_proportional_coefficient": (37, 1),
"Velocity_closed_loop_P_proportional_coefficient": (37, 1),
"Over_Current_Protection_Time": (38, 1),
"Velocity_closed_loop_I_integral_coefficient": (39, 1),
# SRAM
@ -51,32 +63,38 @@ STS_SMS_SERIES_CONTROL_TABLE = {
"Acceleration": (41, 1),
"Goal_Position": (42, 2),
"Goal_Time": (44, 2),
"Goal_Speed": (46, 2),
"Goal_Velocity": (46, 2),
"Torque_Limit": (48, 2),
"Lock": (55, 1),
"Present_Position": (56, 2), # read-only
"Present_Speed": (58, 2), # read-only
"Present_Velocity": (58, 2), # read-only
"Present_Load": (60, 2), # read-only
"Present_Voltage": (62, 1), # read-only
"Present_Temperature": (63, 1), # read-only
"Status": (65, 1), # read-only
"Moving": (66, 1), # read-only
"Present_Current": (69, 2), # read-only
# Not in the Memory Table
"Maximum_Acceleration": (85, 2),
"Goal_Position_2": (71, 2), # read-only
# Factory
"Moving_Velocity": (80, 1),
"Moving_Velocity_Threshold": (80, 1),
"DTs": (81, 1), # (ms)
"Velocity_Unit_factor": (82, 1),
"Hts": (83, 1), # (ns) valid for firmware >= 2.54, other versions keep 0
"Maximum_Velocity_Limit": (84, 1),
"Maximum_Acceleration": (85, 1),
"Acceleration_Multiplier ": (86, 1), # Acceleration multiplier in effect when acceleration is 0
}
# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SCSCL-emanual-cbcc8ab2e3384282a01d4bf3
SCS_SERIES_CONTROL_TABLE = {
# EPROM
"Firmware_Version": FIRMWARE_VERSION, # read-only
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
"Model_Number": MODEL_NUMBER, # read-only
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
# "Model_Minor_Version": MODEL_MINOR_VERSION,
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay": (7, 1),
"Return_Delay_Time": (7, 1),
"Response_Status_Level": (8, 1),
"Min_Position_Limit": (9, 2),
"Max_Position_Limit": (11, 2),
@ -100,38 +118,45 @@ SCS_SERIES_CONTROL_TABLE = {
"Acceleration": (41, 1),
"Goal_Position": (42, 2),
"Running_Time": (44, 2),
"Goal_Speed": (46, 2),
"Goal_Velocity": (46, 2),
"Lock": (48, 1),
"Present_Position": (56, 2), # read-only
"Present_Speed": (58, 2), # read-only
"Present_Velocity": (58, 2), # read-only
"Present_Load": (60, 2), # read-only
"Present_Voltage": (62, 1), # read-only
"Present_Temperature": (63, 1), # read-only
"Sync_Write_Flag": (64, 1), # read-only
"Status": (65, 1), # read-only
"Moving": (66, 1), # read-only
# Factory
"PWM_Maximum_Step": (78, 1),
"Moving_Velocity_Threshold*50": (79, 1),
"DTs": (80, 1), # (ms)
"Minimum_Velocity_Limit*50": (81, 1),
"Maximum_Velocity_Limit*50": (82, 1),
"Acceleration_2": (83, 1), # don't know what that is
}
STS_SMS_SERIES_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
3: 128_000,
4: 115_200,
5: 57_600,
6: 38_400,
7: 19_200,
1_000_000: 0,
500_000: 1,
250_000: 2,
128_000: 3,
115_200: 4,
57_600: 5,
38_400: 6,
19_200: 7,
}
SCS_SERIES_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
3: 128_000,
4: 115_200,
5: 57_600,
6: 38_400,
7: 19_200,
1_000_000: 0,
500_000: 1,
250_000: 2,
128_000: 3,
115_200: 4,
57_600: 5,
38_400: 6,
19_200: 7,
}
MODEL_CONTROL_TABLE = {
@ -150,7 +175,7 @@ MODEL_RESOLUTION = {
"scs_series": 1024,
"sts3215": 4096,
"sts3250": 4096,
"sm8512bl": 4096,
"sm8512bl": 65536,
"scs0009": 1024,
}
@ -167,7 +192,7 @@ MODEL_BAUDRATE_TABLE = {
# Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 11,
"Goal_Speed": 15,
"Goal_Velocity": 15,
}
MODEL_ENCODING_TABLE = {
@ -194,10 +219,19 @@ SCAN_BAUDRATES = [
1_000_000,
]
# {model: model_number} TODO
MODEL_NUMBER_TABLE = {
"sts3215": 777,
"sts3250": None,
"sm8512bl": None,
"sts3250": 2825,
"sm8512bl": 11272,
"scs0009": 1284,
}
MODEL_PROTOCOL = {
"sts_series": 0,
"sms_series": 0,
"scs_series": 1,
"sts3215": 0,
"sts3250": 0,
"sm8512bl": 0,
"scs0009": 1,
}

View File

@ -21,6 +21,7 @@
import abc
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
@ -254,6 +255,7 @@ class MotorsBus(abc.ABC):
"""
available_baudrates: list[int]
default_baudrate: int
default_timeout: int
model_baudrate_table: dict[str, dict]
model_ctrl_table: dict[str, dict]
@ -283,6 +285,8 @@ class MotorsBus(abc.ABC):
self._id_to_name_dict = {m.id: name for name, m in self.motors.items()}
self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()}
self._validate_motors()
def __len__(self):
return len(self.motors)
@ -341,7 +345,7 @@ class MotorsBus(abc.ABC):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_names_list(self, motors: str | list[str] | None) -> list[str]:
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
if motors is None:
return self.names
elif isinstance(motors, str):
@ -375,9 +379,13 @@ class MotorsBus(abc.ABC):
def _assert_motors_exist(self) -> None:
# TODO(aliberts): collect all wrong ids/models and display them at once
found_models = self.broadcast_ping()
found_models = {}
for id_ in self.ids:
model_nb = self.ping(id_)
if model_nb is not None:
found_models[id_] = model_nb
expected_models = {m.id: self.model_number_table[m.model] for m in self.motors.values()}
if not found_models or set(found_models) != set(self.ids):
if set(found_models) != set(self.ids):
raise RuntimeError(
f"{self.__class__.__name__} is supposed to have these motors: ({{id: model_nb}})"
f"\n{pformat(expected_models, indent=4, sort_dicts=False)}\n"
@ -401,36 +409,36 @@ class MotorsBus(abc.ABC):
def is_connected(self) -> bool:
return self.port_handler.is_open
def connect(self, assert_motors_exist: bool = True) -> None:
def connect(self, handshake: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
)
self._connect(handshake)
self.set_timeout()
logger.debug(f"{self.__class__.__name__} connected.")
def _connect(self, handshake: bool = True) -> None:
try:
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
elif assert_motors_exist:
self._assert_motors_exist()
elif handshake:
self._handshake()
except (FileNotFoundError, OSError, serial.SerialException) as e:
raise ConnectionError(
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
) from e
self.set_timeout()
logger.debug(f"{self.__class__.__name__} connected.")
@abc.abstractmethod
def _handshake(self) -> None:
pass
@classmethod
def scan_port(cls, port: str) -> dict[int, list[int]]:
bus = cls(port, {})
try:
bus.port_handler.openPort()
except (FileNotFoundError, OSError, serial.SerialException) as e:
raise ConnectionError(
f"Could not connect to port '{port}'. Make sure you are using the correct port."
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
) from e
def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]:
bus = cls(port, {}, *args, **kwargs)
bus._connect(handshake=False)
baudrate_ids = {}
for baudrate in tqdm(bus.available_baudrates, desc="Scanning port"):
bus.set_baudrate(baudrate)
@ -441,18 +449,57 @@ class MotorsBus(abc.ABC):
return baudrate_ids
def setup_motor(
self, motor: str, initial_baudrate: int | None = None, initial_id: int | None = None
) -> None:
if not self.is_connected:
self._connect(handshake=False)
if initial_baudrate is None:
initial_baudrate, initial_id = self._find_single_motor(motor)
if initial_id is None:
_, initial_id = self._find_single_motor(motor, initial_baudrate)
model = self.motors[motor].model
target_id = self.motors[motor].id
self.set_baudrate(initial_baudrate)
# Set ID
addr, length = get_address(self.model_ctrl_table, "ID", model)
self._write(addr, length, initial_id, target_id)
# Set Baudrate
addr, length = get_address(self.model_ctrl_table, "Baud_Rate", model)
baudrate_value = self.model_baudrate_table[model][self.default_baudrate]
self._write(addr, length, target_id, baudrate_value)
self.set_baudrate(self.default_baudrate)
@abc.abstractmethod
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
pass
@abc.abstractmethod
def configure_motors(self) -> None:
pass
@abc.abstractmethod
def disable_torque(self, motors: str | list[str] | None = None) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
pass
@contextmanager
def torque_disabled(self):
self.disable_torque()
try:
yield
finally:
self.enable_torque()
def set_timeout(self, timeout_ms: int | None = None):
timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout
self.port_handler.setPacketTimeoutMillis(timeout_ms)
@ -473,35 +520,13 @@ class MotorsBus(abc.ABC):
def is_calibrated(self) -> bool:
return self.calibration == self.read_calibration()
@abc.abstractmethod
def read_calibration(self) -> dict[str, MotorCalibration]:
offsets = self.sync_read("Homing_Offset", normalize=False)
mins = self.sync_read("Min_Position_Limit", normalize=False)
maxes = self.sync_read("Max_Position_Limit", normalize=False)
try:
drive_modes = self.sync_read("Drive_Mode", normalize=False)
except KeyError:
drive_modes = dict.fromkeys(self.names, 0)
calibration = {}
for name, motor in self.motors.items():
calibration[name] = MotorCalibration(
id=motor.id,
drive_mode=drive_modes[name],
homing_offset=offsets[name],
range_min=mins[name],
range_max=maxes[name],
)
return calibration
pass
@abc.abstractmethod
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
for motor, calibration in calibration_dict.items():
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
self.calibration = calibration_dict
pass
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
if motors is None:
@ -600,13 +625,15 @@ class MotorsBus(abc.ABC):
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
if not self.calibration:
raise RuntimeError(f"{self} has no calibration registered.")
normalized_values = {}
for id_, val in ids_values.items():
name = self._id_to_name(id_)
min_ = self.calibration[name].range_min
max_ = self.calibration[name].range_max
bounded_val = min(max_, max(min_, val))
# TODO(Steven): normalization can go boom if max_ == min_, we should add a check probably in record_ranges_of_motions (which probably indicates the user forgot to move a motor)
# TODO(Steven): normalization can go boom if max_ == min_, we should add a check probably in record_ranges_of_motions
# (which probably indicates the user forgot to move a motor, most likely a gripper-like one)
if self.motors[name].norm_mode is MotorNormMode.RANGE_M100_100:
normalized_values[id_] = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
elif self.motors[name].norm_mode is MotorNormMode.RANGE_0_100:
@ -618,6 +645,9 @@ class MotorsBus(abc.ABC):
return normalized_values
def _unnormalize(self, data_name: str, ids_values: dict[int, float]) -> dict[int, int]:
if not self.calibration:
raise RuntimeError(f"{self} has no calibration registered.")
unnormalized_values = {}
for id_, val in ids_values.items():
name = self._id_to_name(id_)
@ -643,57 +673,30 @@ class MotorsBus(abc.ABC):
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
pass
def _serialize_data(self, value: int, n_bytes: int) -> list[int]:
def _serialize_data(self, value: int, length: int) -> list[int]:
"""
Converts an unsigned integer value into a list of byte-sized integers to be sent via a communication
protocol. Depending on the protocol, split values can be in big-endian or little-endian order.
This function extracts the individual bytes of an integer based on the
specified number of bytes (`n_bytes`). The output is a list of integers,
each representing a byte (0-255).
**Byte order:** The function returns bytes in **little-endian format**,
meaning the least significant byte (LSB) comes first.
Args:
value (int): The unsigned integer to be converted into a byte list. Must be within
the valid range for the specified `n_bytes`.
n_bytes (int): The number of bytes to use for conversion. Supported values for both Feetech and
Dynamixel:
- 1 (for values 0 to 255)
- 2 (for values 0 to 65,535)
- 4 (for values 0 to 4,294,967,295)
Raises:
ValueError: If `value` is negative or exceeds the maximum allowed for `n_bytes`.
NotImplementedError: If `n_bytes` is not 1, 2, or 4.
Returns:
list[int]: A list of integers, each representing a byte in **little-endian order**.
Examples (for a little-endian protocol):
>>> split_int_bytes(0x12, 1)
[18]
>>> split_int_bytes(0x1234, 2)
[52, 18] # 0x1234 → 0x34 0x12 (little-endian)
>>> split_int_bytes(0x12345678, 4)
[120, 86, 52, 18] # 0x12345678 → 0x78 0x56 0x34 0x12
Supported data length for both Feetech and Dynamixel:
- 1 (for values 0 to 255)
- 2 (for values 0 to 65,535)
- 4 (for values 0 to 4,294,967,295)
"""
if value < 0:
raise ValueError(f"Negative values are not allowed: {value}")
max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(n_bytes)
max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(length)
if max_value is None:
raise NotImplementedError(f"Unsupported byte size: {n_bytes}. Expected [1, 2, 4].")
raise NotImplementedError(f"Unsupported byte size: {length}. Expected [1, 2, 4].")
if value > max_value:
raise ValueError(f"Value {value} exceeds the maximum for {n_bytes} bytes ({max_value}).")
raise ValueError(f"Value {value} exceeds the maximum for {length} bytes ({max_value}).")
return self._split_into_byte_chunks(value, n_bytes)
return self._split_into_byte_chunks(value, length)
@staticmethod
@abc.abstractmethod
def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]:
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
"""Convert an integer into a list of byte-sized integers."""
pass
@ -712,7 +715,7 @@ class MotorsBus(abc.ABC):
return
if self._is_error(error):
if raise_on_error:
raise RuntimeError(self.packet_handler.getTxRxResult(comm))
raise RuntimeError(self.packet_handler.getRxPacketError(error))
else:
return
@ -737,19 +740,10 @@ class MotorsBus(abc.ABC):
id_ = self.motors[motor].id
model = self.motors[motor].model
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
addr, length = get_address(self.model_ctrl_table, model, data_name)
value, comm, error = self._read(addr, n_bytes, id_, num_retry=num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
f"{self.packet_handler.getTxRxResult(comm)}"
)
elif self._is_error(error):
raise RuntimeError(
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
f"\n{self.packet_handler.getRxPacketError(error)}"
)
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
id_value = self._decode_sign(data_name, {id_: value})
@ -758,25 +752,39 @@ class MotorsBus(abc.ABC):
return id_value[id_]
def _read(self, addr: int, n_bytes: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]:
if n_bytes == 1:
def _read(
self,
address: int,
length: int,
motor_id: int,
*,
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif n_bytes == 2:
elif length == 2:
read_fn = self.packet_handler.read2ByteTxRx
elif n_bytes == 4:
elif length == 4:
read_fn = self.packet_handler.read4ByteTxRx
else:
raise ValueError(n_bytes)
raise ValueError(length)
for n_try in range(1 + num_retry):
value, comm, error = read_fn(self.port_handler, motor_id, addr)
value, comm, error = read_fn(self.port_handler, motor_id, address)
if self._is_comm_success(comm):
break
logger.debug(
f"Failed to read @{addr=} ({n_bytes=}) on {motor_id=} ({n_try=}): "
f"Failed to read @{address=} ({length=}) on {motor_id=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
elif self._is_error(error) and raise_on_error:
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
return value, comm, error
def write(
@ -789,38 +797,42 @@ class MotorsBus(abc.ABC):
id_ = self.motors[motor].id
model = self.motors[motor].model
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
addr, length = get_address(self.model_ctrl_table, model, data_name)
if normalize and data_name in self.normalized_data:
value = self._unnormalize(data_name, {id_: value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_]
comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
f"\n{self.packet_handler.getTxRxResult(comm)}"
)
elif self._is_error(error):
raise RuntimeError(
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
f"\n{self.packet_handler.getRxPacketError(error)}"
)
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0
self,
addr: int,
length: int,
motor_id: int,
value: int,
*,
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int]:
data = self._serialize_data(value, n_bytes)
data = self._serialize_data(value, length)
for n_try in range(1 + num_retry):
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data)
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, length, data)
if self._is_comm_success(comm):
break
logger.debug(
f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): "
f"Failed to sync write @{addr=} ({length=}) on id={motor_id} with {value=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
elif self._is_error(error) and raise_on_error:
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
return comm, error
def sync_read(
@ -838,7 +850,7 @@ class MotorsBus(abc.ABC):
self._assert_protocol_is_compatible("sync_read")
names = self._get_names_list(motors)
names = self._get_motors_list(motors)
ids = [self.motors[name].id for name in names]
models = [self.motors[name].model for name in names]
@ -846,14 +858,12 @@ class MotorsBus(abc.ABC):
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
addr, length = get_address(self.model_ctrl_table, model, data_name)
comm, ids_values = self._sync_read(addr, n_bytes, ids, num_retry=num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
f"{self.packet_handler.getTxRxResult(comm)}"
)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
ids_values = self._decode_sign(data_name, ids_values)
@ -863,25 +873,35 @@ class MotorsBus(abc.ABC):
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
def _sync_read(
self, addr: int, n_bytes: int, motor_ids: list[int], num_retry: int = 0
) -> tuple[int, dict[int, int]]:
self._setup_sync_reader(motor_ids, addr, n_bytes)
self,
addr: int,
length: int,
motor_ids: list[int],
*,
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[dict[int, int], int]:
self._setup_sync_reader(motor_ids, addr, length)
for n_try in range(1 + num_retry):
comm = self.sync_reader.txRxPacket()
if self._is_comm_success(comm):
break
logger.debug(
f"Failed to sync read @{addr=} ({n_bytes=}) on {motor_ids=} ({n_try=}): "
f"Failed to sync read @{addr=} ({length=}) on {motor_ids=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
)
values = {id_: self.sync_reader.getData(id_, addr, n_bytes) for id_ in motor_ids}
return comm, values
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
def _setup_sync_reader(self, motor_ids: list[int], addr: int, n_bytes: int) -> None:
values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids}
return values, comm
def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None:
self.sync_reader.clearParam()
self.sync_reader.start_address = addr
self.sync_reader.data_length = n_bytes
self.sync_reader.data_length = length
for id_ in motor_ids:
self.sync_reader.addParam(id_)
@ -889,15 +909,15 @@ class MotorsBus(abc.ABC):
# Would have to handle the logic of checking if a packet has been sent previously though but doable.
# This could be at the cost of increase latency between the moment the data is produced by the motors and
# the moment it is used by a policy.
# def _async_read(self, motor_ids: list[int], address: int, n_bytes: int):
# if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...:
# self._setup_sync_reader(motor_ids, address, n_bytes)
# def _async_read(self, motor_ids: list[int], address: int, length: int):
# if self.sync_reader.start_address != address or self.sync_reader.data_length != length or ...:
# self._setup_sync_reader(motor_ids, address, length)
# else:
# self.sync_reader.rxPacket()
# self.sync_reader.txPacket()
# for id_ in motor_ids:
# value = self.sync_reader.getData(id_, address, n_bytes)
# value = self.sync_reader.getData(id_, address, length)
def sync_write(
self,
@ -918,39 +938,46 @@ class MotorsBus(abc.ABC):
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
addr, length = get_address(self.model_ctrl_table, model, data_name)
if normalize and data_name in self.normalized_data:
ids_values = self._unnormalize(data_name, ids_values)
ids_values = self._encode_sign(data_name, ids_values)
comm = self._sync_write(addr, n_bytes, ids_values, num_retry=num_retry)
if not self._is_comm_success(comm):
raise ConnectionError(
f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
f"\n{self.packet_handler.getTxRxResult(comm)}"
)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _sync_write(self, addr: int, n_bytes: int, ids_values: dict[int, int], num_retry: int = 0) -> int:
self._setup_sync_writer(ids_values, addr, n_bytes)
def _sync_write(
self,
addr: int,
length: int,
ids_values: dict[int, int],
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> int:
self._setup_sync_writer(ids_values, addr, length)
for n_try in range(1 + num_retry):
comm = self.sync_writer.txPacket()
if self._is_comm_success(comm):
break
logger.debug(
f"Failed to sync write @{addr=} ({n_bytes=}) with {ids_values=} ({n_try=}): "
f"Failed to sync write @{addr=} ({length=}) with {ids_values=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
return comm
def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, n_bytes: int) -> None:
def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None:
self.sync_writer.clearParam()
self.sync_writer.start_address = addr
self.sync_writer.data_length = n_bytes
self.sync_writer.data_length = length
for id_, value in ids_values.items():
data = self._serialize_data(value, n_bytes)
data = self._serialize_data(value, length)
self.sync_writer.addParam(id_, data)
def disconnect(self, disable_torque: bool = True) -> None:
@ -962,7 +989,7 @@ class MotorsBus(abc.ABC):
if disable_torque:
self.port_handler.clearPort()
self.port_handler.is_using = False
self.disable_torque()
self.disable_torque(num_retry=5)
self.port_handler.closePort()
logger.debug(f"{self.__class__.__name__} disconnected.")

View File

@ -24,7 +24,7 @@ Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Install pi0 extra dependencies:
```bash
pip install --no-binary=av -e ".[pi0]"
pip install -e ".[pi0]"
```
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):

View File

@ -1,23 +1,16 @@
import abc
import enum
from dataclasses import dataclass
from pathlib import Path
import draccus
class RobotMode(enum.Enum):
TELEOP = 0
AUTO = 1
@dataclass(kw_only=True)
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
# Allows to distinguish between different robots of the same type
id: str | None = None
# Directory to store calibration file
calibration_dir: Path | None = None
robot_mode: RobotMode | None = None
@property
def type(self) -> str:

View File

@ -122,7 +122,7 @@ class KochFollower(Robot):
full_turn_motors = ["shoulder_pan", "wrist_roll"]
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
logger.info(
print(
f"Move all joints except {full_turn_motors} sequentially through their entire "
"ranges of motion.\nRecording positions. Press ENTER to stop..."
)
@ -146,29 +146,28 @@ class KochFollower(Robot):
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
self.arm.disable_torque()
self.arm.configure_motors()
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while
# assembling the arm, you could end up with a servo with a position 0 or 4095 at a crucial
# point
for name in self.arm.names:
if name != "gripper":
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
with self.arm.torque_disabled():
self.arm.configure_motors()
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point
for name in self.arm.names:
if name != "gripper":
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
# Use 'position control current based' for gripper to be limited by the limit of the current.
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
# its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
# to make it move, and it will move back to its original target position when we release the force.
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
# Use 'position control current based' for gripper to be limited by the limit of the current. For
# the follower gripper, it means it can grasp an object without forcing too much even tho, its
# goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with
# our finger to make it move, and it will move back to its original target position when we
# release the force.
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
# Set better PID values to close the gap between recorded states and actions
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
self.arm.write("Position_P_Gain", "elbow_flex", 1500)
self.arm.write("Position_I_Gain", "elbow_flex", 0)
self.arm.write("Position_D_Gain", "elbow_flex", 600)
self.arm.enable_torque()
# Set better PID values to close the gap between recorded states and actions
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
self.arm.write("Position_P_Gain", "elbow_flex", 1500)
self.arm.write("Position_I_Gain", "elbow_flex", 0)
self.arm.write("Position_D_Gain", "elbow_flex", 600)
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:

View File

@ -69,9 +69,15 @@ conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
#### 5. Install LeRobot with dependencies for the feetech motors:
#### 5. Install ffmpeg in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
conda install ffmpeg -c conda-forge
```
#### 6. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
## C. Install LeRobot on laptop
@ -110,9 +116,15 @@ conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
#### 5. Install LeRobot with dependencies for the feetech motors:
#### 5. Install ffmpeg in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
conda install ffmpeg -c conda-forge
```
#### 6. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
@ -414,6 +426,8 @@ python lerobot/scripts/control_robot.py \
--control.fps=30
```
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. For the `--control.type=remote_robot` you will also need to set `--control.viewer_ip` and `--control.viewer_port`
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
| ---------- | ------------------ | ---------------------- |

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
from dataclasses import dataclass, field
from lerobot.common.cameras.configs import CameraConfig
@ -20,6 +21,14 @@ from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfi
from ..config import RobotConfig
class RobotMode(enum.Enum):
TELEOP = 0
AUTO = 1
# TODO(Steven): Consider sending config at initial step over a socket
# However, this isn't practical because anyways we have to configure the
# socket ports to begin with
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiConfig(RobotConfig):
@ -43,6 +52,10 @@ class LeKiwiConfig(RobotConfig):
}
)
# Network Configuration
port_zmq_cmd: int = 5555
port_zmq_observations: int = 5556
@RobotConfig.register_subclass("lekiwi_client")
@dataclass
@ -68,3 +81,5 @@ class LeKiwiClientConfig(RobotConfig):
"quit": "q",
}
)
robot_mode: RobotMode | None = None

View File

@ -116,9 +116,6 @@ class LeKiwi(Robot):
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
# TODO(Steven): I think we should extend this to give the user the option of re-calibrate
# calibrate(recalibrate: bool = False) -> None:
# If true, then we overwrite the previous calibration file with new values
def calibrate(self) -> None:
logger.info(f"\nRunning calibration of {self}")
@ -131,9 +128,15 @@ class LeKiwi(Robot):
input("Move robot to the middle of its range of motion and press ENTER....")
homing_offsets = self.bus.set_half_turn_homings(motors)
# TODO(Steven): Previously homig_offsets was called only on self.arm_motors
# After a discussion, we said it was better to keep it like this and then
# just populate with the rest of motors. However, I don't know which value
# should we use for this
# homing_offsets.update({k,None???} for k in self.base_motors)
# TODO(Steven): Might be worth to do this also in other robots but it should be added in the docs
full_turn_motor = [
motor for motor in motors if any(keyword in motor for keyword in ["wheel", "gripper"])
motor for motor in motors if any(keyword in motor for keyword in ["wheel", "wrist"])
]
unknown_range_motors = [motor for motor in motors if motor not in full_turn_motor]
@ -180,7 +183,7 @@ class LeKiwi(Robot):
for name in self.base_motors:
self.bus.write("Operating_Mode", name, OperatingMode.VELOCITY.value)
self.bus.enable_torque() # TODO(Steven): Operation has failed with: ConnectionError: Failed to write 'Lock' on id_=6 with '1' after 1 tries. [TxRxResult] Incorrect status packet!
self.bus.enable_torque()
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
@ -191,7 +194,7 @@ class LeKiwi(Robot):
# Read actuators position for arm and vel for base
start = time.perf_counter()
arm_pos = self.bus.sync_read("Present_Position", self.arm_motors)
base_vel = self.bus.sync_read("Present_Speed", self.base_motors)
base_vel = self.bus.sync_read("Present_Velocity", self.base_motors)
obs_dict[OBS_STATE] = {**arm_pos, **base_vel}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
@ -238,12 +241,12 @@ class LeKiwi(Robot):
# Send goal position to the actuators
self.bus.sync_write("Goal_Position", arm_goal_pos)
self.bus.sync_write("Goal_Speed", base_goal_vel)
self.bus.sync_write("Goal_Velocity", base_goal_vel)
return {**arm_goal_pos, **base_goal_vel}
def stop_base(self):
self.bus.sync_write("Goal_Speed", dict.fromkeys(self.base_motors, 0), num_retry=5)
self.bus.sync_write("Goal_Velocity", dict.fromkeys(self.base_motors, 0), num_retry=5)
logger.info("Base motors stopped")
def disconnect(self):

View File

@ -24,10 +24,9 @@ import zmq
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError, InvalidActionError
from lerobot.common.robots.config import RobotMode
from ..robot import Robot
from .config_lekiwi import LeKiwiClientConfig
from .config_lekiwi import LeKiwiClientConfig, RobotMode
# TODO(Steven): This doesn't need to inherit from Robot
@ -81,9 +80,8 @@ class LeKiwiClient(Robot):
@property
def state_feature(self) -> dict:
# TODO(Steven): Get this from the data fetched?
# TODO(Steven): Motor names are unknown for the Daemon
# Or assume its size/metadata?
# TODO(Steven): Get this from the data fetched? Motor names are unknown for the Daemon
# For now we assume its size/metadata is known
return {
"dtype": "float64",
"shape": (9,),
@ -108,9 +106,8 @@ class LeKiwiClient(Robot):
@property
def camera_features(self) -> dict[str, dict]:
# TODO(Steven): Get this from the data fetched?
# TODO(Steven): camera names are unknown for the Daemon
# Or assume its size/metadata?
# TODO(Steven): Get this from the data fetched? Motor names are unknown for the Daemon
# For now we assume its size/metadata is known
# TODO(Steven): Check consistency of image sizes
cam_ft = {
"front": {
@ -128,6 +125,8 @@ class LeKiwiClient(Robot):
@property
def is_connected(self) -> bool:
# TODO(Steven): Ideally we could check instead the status of the sockets
# I didn't find any API that allows us to do that easily
return self._is_connected
@property
@ -137,6 +136,7 @@ class LeKiwiClient(Robot):
def connect(self) -> None:
"""Establishes ZMQ sockets with the remote mobile robot"""
# TODO(Steven): Consider instead returning a bool + warn
if self._is_connected:
raise DeviceAlreadyConnectedError(
"LeKiwi Daemon is already connected. Do not run `robot.connect()` twice."
@ -173,7 +173,7 @@ class LeKiwiClient(Robot):
speed_int = -0x8000 # -32768 -> minimum negative value
return speed_int
# Copied from robot_lekiwi MobileManipulator class
# Copied from robot_lekiwi MobileManipulator class* (before the refactor)
@staticmethod
def _raw_to_degps(raw_speed: int) -> float:
steps_per_deg = 4096.0 / 360.0
@ -181,7 +181,7 @@ class LeKiwiClient(Robot):
degps = magnitude / steps_per_deg
return degps
# Copied from robot_lekiwi MobileManipulator class
# Copied from robot_lekiwi MobileManipulator class* (before the refactor)
def _body_to_wheel_raw(
self,
x_cmd: float,
@ -284,6 +284,7 @@ class LeKiwiClient(Robot):
# TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
# TODO(Steven): All this function needs to be refactored
# Copied from robot_lekiwi MobileManipulator class* (before the refactor)
def _get_data(self):
# Copied from robot_lekiwi.py
"""Polls the video socket for up to 15 ms. If data arrives, decode only
@ -368,7 +369,7 @@ class LeKiwiClient(Robot):
if not self._is_connected:
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
# TODO(Steven): remove hard-coded cam name
# TODO(Steven): remove hard-coded cam names & dims
# This is needed at init for when there's no comms
obs_dict = {
OBS_IMAGES: {"wrist": np.zeros(shape=(480, 640, 3)), "front": np.zeros(shape=(640, 480, 3))}
@ -376,7 +377,7 @@ class LeKiwiClient(Robot):
frames, present_speed, remote_arm_state_tensor = self._get_data()
body_state = self._wheel_raw_to_body(present_speed)
# TODO(Steven): output isdict[str,Any] and we multiply by 1000.0. This should be more explicit and specify the expected type instead of Any
# TODO(Steven): output is dict[str,Any] and we multiply by 1000.0. This should be more explicit and specify the expected type instead of Any
body_state_mm = {k: v * 1000.0 for k, v in body_state.items()} # Convert x,y to mm/s
obs_dict[OBS_STATE] = {**remote_arm_state_tensor, **body_state_mm}
@ -422,7 +423,7 @@ class LeKiwiClient(Robot):
def configure(self):
pass
# TODO(Steven): This assumes this call is always called from a keyboard teleop command
# TODO(Steven): This assumes this call is always called with a keyboard as a teleop device. It breaks if we teleop with other device
# TODO(Steven): Doing this mapping in here adds latecy between send_action and movement from the user perspective.
# t0: get teleop_cmd
# t1: send_action(teleop_cmd)
@ -453,7 +454,7 @@ class LeKiwiClient(Robot):
if self.robot_mode is RobotMode.AUTO:
# TODO(Steven): Not yet implemented. The policy outputs might need a different conversion
raise Exception
raise InvalidActionError("Policy output as action input is not yet well defined")
goal_pos = {}
# TODO(Steven): This assumes teleop mode is always used with keyboard. Tomorrow we could teleop with another device ... ?

View File

@ -27,21 +27,17 @@ from lerobot.common.constants import OBS_IMAGES
from .config_lekiwi import LeKiwiConfig
from .lekiwi import LeKiwi
# Network Configuration
PORT_ZMQ_CMD: int = 5555
PORT_ZMQ_OBSERVATIONS: int = 5556
class HostAgent:
def __init__(self):
def __init__(self, port_zmq_cmd, port_zmq_observations):
self.zmq_context = zmq.Context()
self.zmq_cmd_socket = self.zmq_context.socket(zmq.PULL)
self.zmq_cmd_socket.setsockopt(zmq.CONFLATE, 1)
self.zmq_cmd_socket.bind(f"tcp://*:{PORT_ZMQ_CMD}")
self.zmq_cmd_socket.bind(f"tcp://*:{port_zmq_cmd}")
self.zmq_observation_socket = self.zmq_context.socket(zmq.PUSH)
self.zmq_observation_socket.setsockopt(zmq.CONFLATE, 1)
self.zmq_observation_socket.bind(f"tcp://*:{PORT_ZMQ_OBSERVATIONS}")
self.zmq_observation_socket.bind(f"tcp://*:{port_zmq_observations}")
def disconnect(self):
self.zmq_observation_socket.close()
@ -58,7 +54,7 @@ def main():
robot.connect()
logging.info("Starting HostAgent")
remote_agent = HostAgent()
remote_agent = HostAgent(robot_config.port_zmq_cmd, robot_config.port_zmq_observations)
last_cmd_time = time.time()
logging.info("Waiting for commands...")

View File

@ -31,9 +31,15 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
5. Install LeRobot with dependencies for the feetech motors:
5. Install ffmpeg in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
conda install ffmpeg -c conda-forge
```
6. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
## Configure the motors
@ -212,6 +218,9 @@ python lerobot/scripts/control_robot.py \
**Teleop with displaying cameras**
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=moss \

View File

@ -22,11 +22,10 @@ class Robot(abc.ABC):
def __init__(self, config: RobotConfig):
self.robot_type = self.name
self.id = config.id
self.robot_mode = config.robot_mode
self.calibration_dir = (
Path(config.calibration_dir)
if config.calibration_dir
else HF_LEROBOT_CALIBRATION / ROBOTS / self.name
else Path(HF_LEROBOT_CALIBRATION / ROBOTS / self.name)
)
self.calibration_dir.mkdir(parents=True, exist_ok=True)
self.calibration_fpath = self.calibration_dir / f"{self.id}.json"

View File

@ -57,9 +57,15 @@ conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
#### 5. Install LeRobot with dependencies for the feetech motors:
#### 5. Install ffmpeg in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
conda install ffmpeg -c conda-forge
```
#### 6. Install LeRobot with dependencies for the feetech motors:
```bash
cd ~/lerobot && pip install -e ".[feetech]"
```
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
@ -491,6 +497,9 @@ python lerobot/scripts/control_robot.py \
#### a. Teleop with displaying cameras
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=so100 \

View File

@ -55,6 +55,7 @@ class SO100Follower(Robot):
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
@ -120,7 +121,7 @@ class SO100Follower(Robot):
full_turn_motor = "wrist_roll"
unknown_range_motors = [name for name in self.arm.names if name != full_turn_motor]
logger.info(
print(
f"Move all joints except '{full_turn_motor}' sequentially through their "
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
)
@ -143,21 +144,15 @@ class SO100Follower(Robot):
print("Calibration saved to", self.calibration_fpath)
def configure(self) -> None:
self.arm.disable_torque()
self.arm.configure_motors()
for name in self.arm.names:
self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value)
# Set P_Coefficient to lower value to avoid shakiness (Default is 32)
self.arm.write("P_Coefficient", name, 16)
# Set I_Coefficient and D_Coefficient to default value 0 and 32
self.arm.write("I_Coefficient", name, 0)
self.arm.write("D_Coefficient", name, 32)
# Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
# the motors. Note: this address is not in the official STS3215 Memory Table
self.arm.write("Maximum_Acceleration", name, 254)
self.arm.write("Acceleration", name, 254)
self.arm.enable_torque()
with self.arm.torque_disabled():
self.arm.configure_motors()
for name in self.arm.names:
self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value)
# Set P_Coefficient to lower value to avoid shakiness (Default is 32)
self.arm.write("P_Coefficient", name, 16)
# Set I_Coefficient and D_Coefficient to default value 0 and 32
self.arm.write("I_Coefficient", name, 0)
self.arm.write("D_Coefficient", name, 32)
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:

View File

@ -43,14 +43,19 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
6. Install LeRobot with stretch dependencies:
6. When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[stretch]"
conda install ffmpeg -c conda-forge
```
7. Install LeRobot with stretch dependencies:
```bash
cd ~/lerobot && pip install -e ".[stretch]"
```
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
7. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
```bash
stretch_system_check.py
```
@ -97,6 +102,8 @@ This is equivalent to running `stretch_robot_home.py`
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
Now try out teleoperation (see above documentation to learn about the gamepad controls):
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \

View File

@ -65,9 +65,10 @@ def make_robot_from_config(config: RobotConfig):
return ManipulatorRobot(config)
elif isinstance(config, LeKiwiConfig):
from lerobot.common.robots.mobile_manipulator import MobileManipulator
return MobileManipulator(config)
# TODO(Steven): Change when we decide what to do with these scripts
# from lerobot.common.robots.mobile_manipulator import MobileManipulator
# return MobileManipulator(config)
...
else:
from lerobot.common.robots.stretch3.robot_stretch3 import Stretch3Robot

View File

@ -30,9 +30,14 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
5. When using `miniconda`, install `ffmpeg` in your environment:
```bash
cd ~/lerobot && pip install --no-binary=av -e ".[dynamixel, intelrealsense]"
conda install ffmpeg -c conda-forge
```
6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
```bash
cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
```
## Teleoperate
@ -43,6 +48,9 @@ Teleoperation consists in manually operating the leader arms to move the followe
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
By running the following code, you can start your first **SAFE** teleoperation:
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \

View File

@ -117,7 +117,7 @@ class ViperX(Robot):
full_turn_motors = ["shoulder_pan", "wrist_roll"]
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
logger.info(
print(
f"Move all joints except {full_turn_motors} sequentially through their entire "
"ranges of motion.\nRecording positions. Press ENTER to stop..."
)
@ -141,32 +141,31 @@ class ViperX(Robot):
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
self.arm.disable_torque()
self.arm.configure_motors()
with self.arm.torque_disabled():
self.arm.configure_motors()
# Set secondary/shadow ID for shoulder and elbow. These joints have two motors.
# As a result, if only one of them is required to move to a certain position,
# the other will follow. This is to avoid breaking the motors.
self.arm.write("Secondary_ID", "shoulder_shadow", 2)
self.arm.write("Secondary_ID", "elbow_shadow", 4)
# Set secondary/shadow ID for shoulder and elbow. These joints have two motors.
# As a result, if only one of them is required to move to a certain position,
# the other will follow. This is to avoid breaking the motors.
self.arm.write("Secondary_ID", "shoulder_shadow", 2)
self.arm.write("Secondary_ID", "elbow_shadow", 4)
# Set a velocity limit of 131 as advised by Trossen Robotics
# TODO(aliberts): remove as it's actually useless in position control
self.arm.write("Velocity_Limit", 131)
# Set a velocity limit of 131 as advised by Trossen Robotics
# TODO(aliberts): remove as it's actually useless in position control
self.arm.write("Velocity_Limit", 131)
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
# you could end up with a servo with a position 0 or 4095 at a crucial point. See:
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
for name in self.arm.names:
if name != "gripper":
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point.
# See: https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
for name in self.arm.names:
if name != "gripper":
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
# Use 'position control current based' for follower gripper to be limited by the limit of the current.
# It can grasp an object without forcing too much even tho, it's goal position is a complete grasp
# (both gripper fingers are ordered to join and reach a touch).
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
self.arm.enable_torque()
# Use 'position control current based' for follower gripper to be limited by the limit of the
# current. It can grasp an object without forcing too much even tho, it's goal position is a
# complete grasp (both gripper fingers are ordered to join and reach a touch).
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
def get_observation(self) -> dict[str, Any]:
"""The returned observations do not have a batch dimension."""

View File

@ -22,5 +22,5 @@ from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("keyboard")
@dataclass
class KeyboardTeleopConfig(TeleoperatorConfig):
# TODO(Steven): Maybe set in here the keys that we want to capture/listen
# TODO(Steven): Consider setting in here the keys that we want to capture/listen
mock: bool = False

View File

@ -63,12 +63,8 @@ class KeyboardTeleop(Teleoperator):
@property
def action_feature(self) -> dict:
# TODO(Steven): Verify this is correct
return {
"dtype": "float32",
"shape": (len(self.arm),),
"names": {"motors": list(self.arm.motors)},
}
# TODO(Steven): Change this when we agree what should this return
...
@property
def feedback_feature(self) -> dict:
@ -83,15 +79,15 @@ class KeyboardTeleop(Teleoperator):
pass
def connect(self) -> None:
# TODO(Steven): Consider instead of raising a warning and then returning the status
# TODO(Steven): Consider early return instead of raising a warning
# if self._is_connected:
# logging.warning(
# "ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
# "Keyboard is already connected. Do not run `robot.connect()` twice."
# )
# return self._is_connected
if self._is_connected:
raise DeviceAlreadyConnectedError(
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
"Keyboard is already connected. Do not run `robot.connect()` twice."
)
if PYNPUT_AVAILABLE:

View File

@ -102,7 +102,7 @@ class KochLeader(Teleoperator):
full_turn_motors = ["shoulder_pan", "wrist_roll"]
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
logger.info(
print(
f"Move all joints except {full_turn_motors} sequentially through their "
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
)

View File

@ -51,6 +51,7 @@ class SO100Leader(Teleoperator):
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
@property

View File

@ -45,9 +45,6 @@ class Teleoperator(abc.ABC):
def is_connected(self) -> bool:
pass
# TODO(Steven): I think connect() should return a bool, such that the client/application code can check if the device connected successfully
# if not device.connect():
# raise DeviceNotConnectedError(f"{device} failed to connect")
@abc.abstractmethod
def connect(self) -> None:
"""Connects to the teleoperator."""

View File

@ -99,7 +99,7 @@ class WidowX(Teleoperator):
full_turn_motors = ["shoulder_pan", "wrist_roll"]
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
logger.info(
print(
f"Move all joints except {full_turn_motors} sequentially through their "
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
)

View File

@ -24,7 +24,7 @@ from contextlib import nullcontext
from copy import copy
from functools import cache
import cv2
import rerun as rr
import torch
from deepdiff import DeepDiff
from termcolor import colored
@ -174,13 +174,13 @@ def warmup_record(
events,
enable_teleoperation,
warmup_time_s,
display_cameras,
display_data,
fps,
):
control_loop(
robot=robot,
control_time_s=warmup_time_s,
display_cameras=display_cameras,
display_data=display_data,
events=events,
fps=fps,
teleoperate=enable_teleoperation,
@ -192,7 +192,7 @@ def record_episode(
dataset,
events,
episode_time_s,
display_cameras,
display_data,
policy,
fps,
single_task,
@ -200,7 +200,7 @@ def record_episode(
control_loop(
robot=robot,
control_time_s=episode_time_s,
display_cameras=display_cameras,
display_data=display_data,
dataset=dataset,
events=events,
policy=policy,
@ -215,7 +215,7 @@ def control_loop(
robot,
control_time_s=None,
teleoperate=False,
display_cameras=False,
display_data=False,
dataset: LeRobotDataset | None = None,
events=None,
policy: PreTrainedPolicy = None,
@ -264,11 +264,15 @@ def control_loop(
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
if display_cameras and not is_headless():
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
for k, v in action.items():
for i, vv in enumerate(v):
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
rr.log(key, rr.Image(observation[key].numpy()), static=True)
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
@ -297,15 +301,11 @@ def reset_environment(robot, events, reset_time_s, fps):
)
def stop_recording(robot, listener, display_cameras):
def stop_recording(robot, listener, display_data):
robot.disconnect()
if not is_headless():
if listener is not None:
listener.stop()
if display_cameras:
cv2.destroyAllWindows()
if not is_headless() and listener is not None:
listener.stop()
def sanity_check_dataset_name(repo_id, policy_cfg):

View File

@ -41,7 +41,7 @@ class TeleoperateControlConfig(ControlConfig):
fps: int | None = None
teleop_time_s: float | None = None
# Display all cameras on screen
display_cameras: bool = True
display_data: bool = False
@ControlConfig.register_subclass("record")
@ -82,7 +82,7 @@ class RecordControlConfig(ControlConfig):
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Display all cameras on screen
display_cameras: bool = True
display_data: bool = False
# Use vocal synthesis to read events.
play_sounds: bool = True
# Resume recording on an existing dataset.
@ -116,6 +116,11 @@ class ReplayControlConfig(ControlConfig):
@dataclass
class RemoteRobotConfig(ControlConfig):
log_interval: int = 100
# Display all cameras on screen
display_data: bool = False
# Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp)
viewer_ip: str | None = None
viewer_port: str | None = None
@dataclass

View File

@ -135,10 +135,13 @@ python lerobot/scripts/control_robot.py \
"""
import logging
import os
import time
from dataclasses import asdict
from pprint import pformat
import rerun as rr
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
@ -146,6 +149,7 @@ from lerobot.common.robots.utils import Robot, make_robot_from_config
from lerobot.common.utils.control_utils import (
control_loop,
init_keyboard_listener,
is_headless,
log_control_info,
record_episode,
reset_environment,
@ -159,6 +163,7 @@ from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.configs import parser
from lerobot.configs.control import (
CalibrateControlConfig,
ControlConfig,
ControlPipelineConfig,
RecordControlConfig,
RemoteRobotConfig,
@ -232,7 +237,7 @@ def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
control_time_s=cfg.teleop_time_s,
fps=cfg.fps,
teleoperate=True,
display_cameras=cfg.display_cameras,
display_data=cfg.display_data,
)
@ -280,7 +285,7 @@ def record(
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
log_say("Warmup record", cfg.play_sounds)
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
@ -296,7 +301,7 @@ def record(
dataset=dataset,
events=events,
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
display_data=cfg.display_data,
policy=policy,
fps=cfg.fps,
single_task=cfg.single_task,
@ -326,7 +331,7 @@ def record(
break
log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras)
stop_recording(robot, listener, cfg.display_data)
if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
@ -363,6 +368,40 @@ def replay(
log_control_info(robot, dt_s, fps=cfg.fps)
def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None:
"""Initializes the Rerun SDK for visualizing the control loop.
Args:
control_config: Configuration determining data display and robot type.
session_name: Rerun session name. Defaults to "lerobot_control_loop".
Raises:
ValueError: If viewer IP is missing for non-remote configurations with display enabled.
"""
if (control_config.display_data and not is_headless()) or (
control_config.display_data and isinstance(control_config, RemoteRobotConfig)
):
# Configure Rerun flush batch size default to 8KB if not set
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
# Initialize Rerun based on configuration
rr.init(session_name)
if isinstance(control_config, RemoteRobotConfig):
viewer_ip = control_config.viewer_ip
viewer_port = control_config.viewer_port
if not viewer_ip or not viewer_port:
raise ValueError(
"Viewer IP & Port are required for remote config. Set via config file/CLI or disable control_config.display_data."
)
logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}")
rr.connect_tcp(f"{viewer_ip}:{viewer_port}")
else:
# Get memory limit for rerun viewer parameters
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
rr.spawn(memory_limit=memory_limit)
@parser.wrap()
def control_robot(cfg: ControlPipelineConfig):
init_logging()
@ -370,18 +409,24 @@ def control_robot(cfg: ControlPipelineConfig):
robot = make_robot_from_config(cfg.robot)
# TODO(Steven): Blueprint for fixed window size
if isinstance(cfg.control, CalibrateControlConfig):
calibrate(robot, cfg.control)
elif isinstance(cfg.control, TeleoperateControlConfig):
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop")
teleoperate(robot, cfg.control)
elif isinstance(cfg.control, RecordControlConfig):
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record")
record(robot, cfg.control)
elif isinstance(cfg.control, ReplayControlConfig):
replay(robot, cfg.control)
elif isinstance(cfg.control, RemoteRobotConfig):
from lerobot.common.robots.lekiwi.old_lekiwi_remote import run_lekiwi
run_lekiwi(cfg.robot)
...
# TODO(Steven): Change this when we decide what to do with the control_robot script
# from lerobot.common.robots.lekiwi.old_lekiwi_remote import run_lekiwi
# _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote")
# run_lekiwi(cfg.robot)
if robot.is_connected:
# Disconnect manually to avoid a "Core dump" during process

View File

@ -60,9 +60,9 @@ dependencies = [
"jsonlines>=4.0.0",
"numba>=0.59.0",
"omegaconf>=2.3.0",
"opencv-python>=4.9.0",
"opencv-python-headless>=4.9.0",
"packaging>=24.2",
"av>=12.0.5,<13.0.0",
"av>=12.0.5",
"pymunk>=6.6.0",
"pynput>=1.7.7",
"pyzmq>=26.2.1",

View File

@ -5,7 +5,7 @@ import dynamixel_sdk as dxl
import serial
from mock_serial.mock_serial import MockSerial
from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE, DynamixelMotorsBus
from lerobot.common.motors.dynamixel.dynamixel import _split_into_byte_chunks
from .mock_serial_patch import WaitableStub
@ -45,41 +45,6 @@ DXL_CRC_TABLE = [
0x8213, 0x0216, 0x021C, 0x8219, 0x0208, 0x820D, 0x8207, 0x0202
] # fmt: skip
# https://emanual.robotis.com/docs/en/dxl/protocol2/#instruction
INSTRUCTION_TYPES = {
"Ping": dxl.INST_PING, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID
"Read": dxl.INST_READ, # Read data from the Device
"Write": dxl.INST_WRITE, # Write data to the Device
"Reg_Write": dxl.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command
"Action": dxl.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write
"Factory_Reset": dxl.INST_FACTORY_RESET, # Resets the Control Table to its initial factory default settings
"Reboot": dxl.INST_REBOOT, # Reboot the Device
"Clear": dxl.INST_CLEAR, # Reset certain information stored in memory
"Control_Table_Backup": 0x20, # Store current Control Table status data to a Backup or to restore backup EEPROM data.
"Status": dxl.INST_STATUS, # Return packet sent following the execution of an Instruction Packet
"Sync_Read": dxl.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once
"Sync_Write": dxl.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once
"Fast_Sync_Read": 0x8A, # Read data from multiple devices with the same Address with the same length at once
"Bulk_Read": dxl.INST_BULK_READ, # Read data from multiple devices with different Addresses with different lengths at once
"Bulk_Write": dxl.INST_BULK_WRITE, # Write data to multiple devices with different Addresses with different lengths at once
"Fast_Bulk_Read": 0x9A, # Read data from multiple devices with different Addresses with different lengths at once
} # fmt: skip
# https://emanual.robotis.com/docs/en/dxl/protocol2/#error
ERROR_TYPE = {
"Success": 0x00, # No error
"Result_Fail": dxl.ERRNUM_RESULT_FAIL, # Failed to process the sent Instruction Packet
"Instruction_Error": dxl.ERRNUM_INSTRUCTION, # An undefined Instruction has been usedAction has been used without Reg Write
"CRC_Error": dxl.ERRNUM_CRC, # The CRC of the sent Packet does not match the expected value
"Data_Range_Error": dxl.ERRNUM_DATA_RANGE, # Data to be written to the specified Address is outside the range of the minimum/maximum value
"Data_Length_Error": dxl.ERRNUM_DATA_LENGTH, # Attempted to write Data that is shorter than the required data length of the specified Address
# (ex: when you attempt to only use 2 bytes of a register that has been defined as 4 bytes)
"Data_Limit_Error": dxl.ERRNUM_DATA_LIMIT, # Data to be written to the specified Address is outside of the configured Limit value
"Access_Error": dxl.ERRNUM_ACCESS, # Attempted to write a value to an Address that is Read Only or has not been defined
# Attempted to read a value from an Address that is Write Only or has not been defined
# Attempted to write a value to an EEPROM register while Torque was Enabled.
} # fmt: skip
class MockDynamixelPacketv2(abc.ABC):
@classmethod
@ -186,14 +151,14 @@ class MockInstructionPacket(MockDynamixelPacketv2):
"""
@classmethod
def _build(cls, dxl_id: int, params: list[int], length: int, instruct_type: str) -> list[int]:
instruct_value = INSTRUCTION_TYPES[instruct_type]
def _build(cls, dxl_id: int, params: list[int], length: int, instruction: int) -> list[int]:
length = len(params) + 3
return [
0xFF, 0xFF, 0xFD, 0x00, # header
dxl_id, # servo id
dxl.DXL_LOBYTE(length), # length_l
dxl.DXL_HIBYTE(length), # length_h
instruct_value, # instruction type
instruction, # instruction type
*params, # data bytes
0x00, 0x00 # placeholder for CRC
] # fmt: skip
@ -209,8 +174,39 @@ class MockInstructionPacket(MockDynamixelPacketv2):
No parameters required.
"""
params, length = [], 3
return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Ping")
return cls.build(dxl_id=dxl_id, params=[], length=3, instruction=dxl.INST_PING)
@classmethod
def read(
cls,
dxl_id: int,
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Read" instruction.
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
The parameters for Read (Protocol 2.0) are:
param[0] = start_address L
param[1] = start_address H
param[2] = data_length L
param[3] = data_length H
And 'length' = data_length + 5, where:
+1 is for instruction byte,
+2 is for the length bytes,
+2 is for the CRC at the end.
"""
params = [
dxl.DXL_LOBYTE(start_address),
dxl.DXL_HIBYTE(start_address),
dxl.DXL_LOBYTE(data_length),
dxl.DXL_HIBYTE(data_length),
]
length = len(params) + 3
# length = data_length + 5
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_READ)
@classmethod
def write(
@ -237,14 +233,14 @@ class MockInstructionPacket(MockDynamixelPacketv2):
+2 is for the length bytes,
+2 is for the CRC at the end.
"""
data = DynamixelMotorsBus._split_int_to_bytes(value, data_length)
data = _split_into_byte_chunks(value, data_length)
params = [
dxl.DXL_LOBYTE(start_address),
dxl.DXL_HIBYTE(start_address),
*data,
]
length = data_length + 5
return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Write")
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_WRITE)
@classmethod
def sync_read(
@ -278,7 +274,9 @@ class MockInstructionPacket(MockDynamixelPacketv2):
*dxl_ids,
]
length = len(dxl_ids) + 7
return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read")
return cls.build(
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_READ
)
@classmethod
def sync_write(
@ -315,7 +313,7 @@ class MockInstructionPacket(MockDynamixelPacketv2):
"""
data = []
for id_, value in ids_values.items():
split_value = DynamixelMotorsBus._split_int_to_bytes(value, data_length)
split_value = _split_into_byte_chunks(value, data_length)
data += [id_, *split_value]
params = [
dxl.DXL_LOBYTE(start_address),
@ -325,7 +323,9 @@ class MockInstructionPacket(MockDynamixelPacketv2):
*data,
]
length = len(ids_values) * (1 + data_length) + 7
return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write")
return cls.build(
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_WRITE
)
class MockStatusPacket(MockDynamixelPacketv2):
@ -341,21 +341,20 @@ class MockStatusPacket(MockDynamixelPacketv2):
"""
@classmethod
def _build(cls, dxl_id: int, params: list[int], length: int, error: str = "Success") -> list[int]:
err_byte = ERROR_TYPE[error]
def _build(cls, dxl_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
return [
0xFF, 0xFF, 0xFD, 0x00, # header
dxl_id, # servo id
dxl.DXL_LOBYTE(length), # length_l
dxl.DXL_HIBYTE(length), # length_h
0x55, # instruction = 'status'
err_byte, # error
error, # error
*params, # data bytes
0x00, 0x00 # placeholder for CRC
] # fmt: skip
@classmethod
def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50) -> bytes:
def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50, error: int = 0) -> bytes:
"""
Builds a 'Ping' status packet.
https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01
@ -372,10 +371,10 @@ class MockStatusPacket(MockDynamixelPacketv2):
"""
params = [dxl.DXL_LOBYTE(model_nb), dxl.DXL_HIBYTE(model_nb), firm_ver]
length = 7
return cls.build(dxl_id, params=params, length=length)
return cls.build(dxl_id, params=params, length=length, error=error)
@classmethod
def read(cls, dxl_id: int, value: int, param_length: int) -> bytes:
def read(cls, dxl_id: int, value: int, param_length: int, error: int = 0) -> bytes:
"""
Builds a 'Read' status packet (also works for 'Sync Read')
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
@ -389,9 +388,9 @@ class MockStatusPacket(MockDynamixelPacketv2):
Returns:
bytes: The raw 'Present_Position' status packet ready to be sent through serial.
"""
params = DynamixelMotorsBus._split_int_to_bytes(value, param_length)
params = _split_into_byte_chunks(value, param_length)
length = param_length + 4
return cls.build(dxl_id, params=params, length=length)
return cls.build(dxl_id, params=params, length=length, error=error)
class MockPortHandler(dxl.PortHandler):
@ -425,8 +424,6 @@ class MockMotors(MockSerial):
instruction packets. It is meant to test MotorsBus classes.
"""
ctrl_table = X_SERIES_CONTROL_TABLE
def __init__(self):
super().__init__()
@ -455,10 +452,10 @@ class MockMotors(MockSerial):
return stub_name
def build_ping_stub(
self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0
self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0, error: int = 0
) -> str:
ping_request = MockInstructionPacket.ping(dxl_id)
return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver)
return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver, error)
ping_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Ping_{dxl_id}"
self.stub(
@ -468,14 +465,63 @@ class MockMotors(MockSerial):
)
return stub_name
def build_sync_read_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
def build_read_stub(
self,
address: int,
length: int,
dxl_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
read_request = MockInstructionPacket.read(dxl_id, address, length)
return_packet = MockStatusPacket.read(dxl_id, value, length, error) if reply else b""
read_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Read_{address}_{length}_{dxl_id}_{value}_{error}"
self.stub(
name=stub_name,
receive_bytes=read_request,
send_fn=read_response,
)
return stub_name
def build_write_stub(
self,
address: int,
length: int,
dxl_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length)
return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error) if reply else b""
stub_name = f"Write_{address}_{length}_{dxl_id}"
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=self._build_send_fn(return_packet, num_invalid_try),
)
return stub_name
def build_sync_read_stub(
self,
address: int,
length: int,
ids_values: dict[int, int],
reply: bool = True,
num_invalid_try: int = 0,
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
return_packets = b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
return_packets = (
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
if reply
else b""
)
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
stub_name = f"Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
@ -484,11 +530,10 @@ class MockMotors(MockSerial):
return stub_name
def build_sequential_sync_read_stub(
self, data_name: str, ids_values: dict[int, list[int]] | None = None
self, address: int, length: int, ids_values: dict[int, list[int]] | None = None
) -> str:
sequence_length = len(next(iter(ids_values.values())))
assert all(len(positions) == sequence_length for positions in ids_values.values())
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
sequential_packets = []
for count in range(sequence_length):
@ -498,7 +543,7 @@ class MockMotors(MockSerial):
sequential_packets.append(return_packets)
sync_read_response = self._build_sequential_send_fn(sequential_packets)
stub_name = f"Seq_Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
@ -507,11 +552,10 @@ class MockMotors(MockSerial):
return stub_name
def build_sync_write_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
stub_name = f"Sync_Write_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
@ -519,20 +563,6 @@ class MockMotors(MockSerial):
)
return stub_name
def build_write_stub(
self, data_name: str, dxl_id: int, value: int, error: str = "Success", num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length)
return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error)
stub_name = f"Write_{data_name}_{dxl_id}"
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=self._build_send_fn(return_packet, num_invalid_try),
)
return stub_name
@staticmethod
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes:

View File

@ -5,32 +5,10 @@ import scservo_sdk as scs
import serial
from mock_serial import MockSerial
from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE, FeetechMotorsBus
from lerobot.common.motors.feetech.feetech import patch_setPacketTimeout
from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout
from .mock_serial_patch import WaitableStub
# https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf
INSTRUCTION_TYPES = {
"Read": scs.INST_PING, # Read data from the Device
"Ping": scs.INST_READ, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID
"Write": scs.INST_WRITE, # Write data to the Device
"Reg_Write": scs.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command
"Action": scs.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write
"Factory_Reset": 0x06, # Resets the Control Table to its initial factory default settings
"Sync_Write": scs.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once
"Sync_Read": scs.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once
} # fmt: skip
ERROR_TYPE = {
"Success": 0x00,
"Voltage": scs.ERRBIT_VOLTAGE,
"Angle": scs.ERRBIT_ANGLE,
"Overheat": scs.ERRBIT_OVERHEAT,
"Overele": scs.ERRBIT_OVERELE,
"Overload": scs.ERRBIT_OVERLOAD,
}
class MockFeetechPacket(abc.ABC):
@classmethod
@ -49,7 +27,7 @@ class MockFeetechPacket(abc.ABC):
for id_ in range(2, len(packet) - 1): # except header & checksum
checksum += packet[id_]
packet[-1] = scs.SCS_LOBYTE(~checksum)
packet[-1] = ~checksum & 0xFF
return packet
@ -68,15 +46,14 @@ class MockInstructionPacket(MockFeetechPacket):
"""
@classmethod
def _build(cls, scs_id: int, params: list[int], length: int, instruct_type: str) -> list[int]:
instruct_value = INSTRUCTION_TYPES[instruct_type]
def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]:
return [
0xFF, 0xFF, # header
scs_id, # servo id
length, # length
instruct_value, # instruction type
*params, # data bytes
0x00, # placeholder for checksum
0xFF, 0xFF, # header
scs_id, # servo id
length, # length
instruction, # instruction type
*params, # data bytes
0x00, # placeholder for checksum
] # fmt: skip
@classmethod
@ -89,7 +66,7 @@ class MockInstructionPacket(MockFeetechPacket):
No parameters required.
"""
return cls.build(scs_id=scs_id, params=[], length=2, instruct_type="Ping")
return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING)
@classmethod
def read(
@ -113,7 +90,7 @@ class MockInstructionPacket(MockFeetechPacket):
"""
params = [start_address, data_length]
length = 4
return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Read")
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ)
@classmethod
def write(
@ -139,10 +116,10 @@ class MockInstructionPacket(MockFeetechPacket):
+1 is for the length bytes,
+1 is for the checksum at the end.
"""
data = FeetechMotorsBus._split_int_to_bytes(value, data_length)
data = _split_into_byte_chunks(value, data_length)
params = [start_address, *data]
length = data_length + 3
return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write")
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE)
@classmethod
def sync_read(
@ -167,7 +144,9 @@ class MockInstructionPacket(MockFeetechPacket):
"""
params = [start_address, data_length, *scs_ids]
length = len(scs_ids) + 4
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read")
return cls.build(
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ
)
@classmethod
def sync_write(
@ -201,11 +180,13 @@ class MockInstructionPacket(MockFeetechPacket):
"""
data = []
for id_, value in ids_values.items():
split_value = FeetechMotorsBus._split_int_to_bytes(value, data_length)
split_value = _split_into_byte_chunks(value, data_length)
data += [id_, *split_value]
params = [start_address, data_length, *data]
length = len(ids_values) * (1 + data_length) + 4
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write")
return cls.build(
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE
)
class MockStatusPacket(MockFeetechPacket):
@ -222,19 +203,18 @@ class MockStatusPacket(MockFeetechPacket):
"""
@classmethod
def _build(cls, scs_id: int, params: list[int], length: int, error: str = "Success") -> list[int]:
err_byte = ERROR_TYPE[error]
def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
return [
0xFF, 0xFF, # header
scs_id, # servo id
length, # length
err_byte, # status
error, # status
*params, # data bytes
0x00, # placeholder for checksum
] # fmt: skip
@classmethod
def ping(cls, scs_id: int, error: str = "Success") -> bytes:
def ping(cls, scs_id: int, error: int = 0) -> bytes:
"""Builds a 'Ping' status packet.
Args:
@ -247,7 +227,7 @@ class MockStatusPacket(MockFeetechPacket):
return cls.build(scs_id, params=[], length=2, error=error)
@classmethod
def read(cls, scs_id: int, value: int, param_length: int) -> bytes:
def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes:
"""Builds a 'Read' status packet.
Args:
@ -258,9 +238,9 @@ class MockStatusPacket(MockFeetechPacket):
Returns:
bytes: The raw 'Sync Read' status packet ready to be sent through serial.
"""
params = FeetechMotorsBus._split_int_to_bytes(value, param_length)
params = _split_into_byte_chunks(value, param_length)
length = param_length + 2
return cls.build(scs_id, params=params, length=length)
return cls.build(scs_id, params=params, length=length, error=error)
class MockPortHandler(scs.PortHandler):
@ -297,8 +277,6 @@ class MockMotors(MockSerial):
instruction packets. It is meant to test MotorsBus classes.
"""
ctrl_table = STS_SMS_SERIES_CONTROL_TABLE
def __init__(self):
super().__init__()
@ -323,11 +301,11 @@ class MockMotors(MockSerial):
)
return stub_name
def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0) -> str:
def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str:
ping_request = MockInstructionPacket.ping(scs_id)
return_packet = MockStatusPacket.ping(scs_id)
return_packet = MockStatusPacket.ping(scs_id, error)
ping_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Ping_{scs_id}"
stub_name = f"Ping_{scs_id}_{error}"
self.stub(
name=stub_name,
receive_bytes=ping_request,
@ -336,13 +314,19 @@ class MockMotors(MockSerial):
return stub_name
def build_read_stub(
self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0
self,
address: int,
length: int,
scs_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
address, length = self.ctrl_table[data_name]
read_request = MockInstructionPacket.read(scs_id, address, length)
return_packet = MockStatusPacket.read(scs_id, value, length)
return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b""
read_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Read_{data_name}_{scs_id}"
stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}"
self.stub(
name=stub_name,
receive_bytes=read_request,
@ -350,15 +334,42 @@ class MockMotors(MockSerial):
)
return stub_name
def build_sync_read_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
def build_write_stub(
self,
address: int,
length: int,
scs_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
return_packets = b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
sync_read_request = MockInstructionPacket.write(scs_id, value, address, length)
return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b""
stub_name = f"Write_{address}_{length}_{scs_id}"
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=self._build_send_fn(return_packet, num_invalid_try),
)
return stub_name
def build_sync_read_stub(
self,
address: int,
length: int,
ids_values: dict[int, int],
reply: bool = True,
num_invalid_try: int = 0,
) -> str:
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
return_packets = (
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
if reply
else b""
)
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
stub_name = f"Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
@ -367,11 +378,10 @@ class MockMotors(MockSerial):
return stub_name
def build_sequential_sync_read_stub(
self, data_name: str, ids_values: dict[int, list[int]] | None = None
self, address: int, length: int, ids_values: dict[int, list[int]] | None = None
) -> str:
sequence_length = len(next(iter(ids_values.values())))
assert all(len(positions) == sequence_length for positions in ids_values.values())
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
sequential_packets = []
for count in range(sequence_length):
@ -381,7 +391,7 @@ class MockMotors(MockSerial):
sequential_packets.append(return_packets)
sync_read_response = self._build_sequential_send_fn(sequential_packets)
stub_name = f"Seq_Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
@ -390,11 +400,10 @@ class MockMotors(MockSerial):
return stub_name
def build_sync_write_stub(
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
stub_name = f"Sync_Write_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
@ -402,20 +411,6 @@ class MockMotors(MockSerial):
)
return stub_name
def build_write_stub(
self, data_name: str, scs_id: int, value: int, error: str = "Success", num_invalid_try: int = 0
) -> str:
address, length = self.ctrl_table[data_name]
sync_read_request = MockInstructionPacket.write(scs_id, value, address, length)
return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error)
stub_name = f"Write_{data_name}_{scs_id}"
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=self._build_send_fn(return_packet, num_invalid_try),
)
return stub_name
@staticmethod
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes:

View File

@ -1,3 +1,4 @@
import re
import sys
from typing import Generator
from unittest.mock import MagicMock, patch
@ -7,6 +8,7 @@ import pytest
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.common.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus
from lerobot.common.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE
from lerobot.common.utils.encoding_utils import encode_twos_complement
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
@ -62,48 +64,21 @@ def test_autouse_patch():
@pytest.mark.parametrize(
"value, n_bytes, expected",
"value, length, expected",
[
(0x12, 1, [0x12]),
(0x1234, 2, [0x34, 0x12]),
(0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
(0, 1, [0x00]),
(0, 2, [0x00, 0x00]),
(0, 4, [0x00, 0x00, 0x00, 0x00]),
(255, 1, [0xFF]),
(65535, 2, [0xFF, 0xFF]),
(4294967295, 4, [0xFF, 0xFF, 0xFF, 0xFF]),
],
ids=[
"1 byte",
"2 bytes",
"4 bytes",
"0 with 1 byte",
"0 with 2 bytes",
"0 with 4 bytes",
"max single byte",
"max two bytes",
"max four bytes",
],
) # fmt: skip
def test_split_int_to_bytes(value, n_bytes, expected):
assert DynamixelMotorsBus._split_int_to_bytes(value, n_bytes) == expected
def test_split_int_to_bytes_invalid_n_bytes():
with pytest.raises(NotImplementedError):
DynamixelMotorsBus._split_int_to_bytes(100, 3)
def test_split_int_to_bytes_negative_numbers():
with pytest.raises(ValueError):
neg = DynamixelMotorsBus._split_int_to_bytes(-1, 1)
print(neg)
def test_split_int_to_bytes_large_number():
with pytest.raises(ValueError):
DynamixelMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
def test__split_into_byte_chunks(value, length, expected):
bus = DynamixelMotorsBus("", {})
assert bus._split_into_byte_chunks(value, length) == expected
def test_abc_implementation(dummy_motors):
@ -114,204 +89,195 @@ def test_abc_implementation(dummy_motors):
@pytest.mark.parametrize("id_", [1, 2, 3])
def test_ping(id_, mock_motors, dummy_motors):
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
stub_name = mock_motors.build_ping_stub(id_, expected_model_nb)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
stub = mock_motors.build_ping_stub(id_, expected_model_nb)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
ping_model_nb = motors_bus.ping(id_)
ping_model_nb = bus.ping(id_)
assert ping_model_nb == expected_model_nb
assert mock_motors.stubs[stub_name].called
assert mock_motors.stubs[stub].called
def test_broadcast_ping(mock_motors, dummy_motors):
models = {m.id: m.model for m in dummy_motors.values()}
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
stub_name = mock_motors.build_broadcast_ping_stub(expected_model_nbs)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
ping_model_nbs = motors_bus.broadcast_ping()
ping_model_nbs = bus.broadcast_ping()
assert ping_model_nbs == expected_model_nbs
assert mock_motors.stubs[stub_name].called
def test_sync_read_none(mock_motors, dummy_motors):
expected_positions = {
"dummy_1": 1337,
"dummy_2": 42,
"dummy_3": 4016,
}
ids_values = dict(zip([1, 2, 3], expected_positions.values(), strict=True))
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
read_positions = motors_bus.sync_read("Present_Position", normalize=False)
assert mock_motors.stubs[stub_name].called
assert read_positions == expected_positions
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"id_, position",
"addr, length, id_, value",
[
(1, 1337),
(2, 42),
(3, 4016),
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test_sync_read_single_value(id_, position, mock_motors, dummy_motors):
expected_position = {f"dummy_{id_}": position}
stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position})
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
stub = mock_motors.build_read_stub(addr, length, id_, value)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
read_position = motors_bus.sync_read("Present_Position", f"dummy_{id_}", normalize=False)
read_value, _, _ = bus._read(addr, length, id_)
assert mock_motors.stubs[stub_name].called
assert read_position == expected_position
assert mock_motors.stubs[stub].called
assert read_value == value
@pytest.mark.parametrize(
"ids, positions",
[
([1], [1337]),
([1, 2], [1337, 42]),
([1, 2, 3], [1337, 42, 4016]),
],
ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip
def test_sync_read(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions)
names = [f"dummy_{dxl_id}" for dxl_id in ids]
expected_positions = dict(zip(names, positions, strict=True))
ids_values = dict(zip(ids, positions, strict=True))
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
read_positions = motors_bus.sync_read("Present_Position", names, normalize=False)
assert mock_motors.stubs[stub_name].called
assert read_positions == expected_positions
@pytest.mark.parametrize(
"num_retry, num_invalid_try, pos",
[
(0, 2, 1337),
(2, 3, 42),
(3, 2, 4016),
(2, 1, 999),
],
)
def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors):
expected_position = {"dummy_1": pos}
stub_name = mock_motors.build_sync_read_stub(
"Present_Position", {1: pos}, num_invalid_try=num_invalid_try
)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
if num_retry >= num_invalid_try:
pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
assert pos_dict == expected_position
if raise_on_error:
with pytest.raises(
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
):
bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
with pytest.raises(ConnectionError):
_ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_error == error
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
assert mock_motors.stubs[stub_name].calls == expected_calls
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_comm == dxl.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"data_name, value",
"addr, length, id_, value",
[
("Torque_Enable", 0),
("Torque_Enable", 1),
("Goal_Position", 1337),
("Goal_Position", 42),
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors):
ids_values = {m.id: value for m in dummy_motors.values()}
stub_name = mock_motors.build_sync_write_stub(data_name, ids_values)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
stub = mock_motors.build_write_stub(addr, length, id_, value)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
motors_bus.sync_write(data_name, value, normalize=False)
comm, error = bus._write(addr, length, id_, value)
assert mock_motors.stubs[stub_name].wait_called()
assert mock_motors.stubs[stub].called
assert comm == dxl.COMM_SUCCESS
assert error == 0
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
):
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_error == error
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_comm == dxl.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"ids, positions",
"addr, length, ids_values",
[
([1], [1337]),
([1, 2], [1337, 42]),
([1, 2, 3], [1337, 42, 4016]),
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip
def test_sync_write(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions)
ids_values = dict(zip(ids, positions, strict=True))
stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
)
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
write_values = {f"dummy_{id_}": pos for id_, pos in ids_values.items()}
motors_bus.sync_write("Goal_Position", write_values, normalize=False)
read_values, _ = bus._sync_read(addr, length, list(ids_values))
assert mock_motors.stubs[stub_name].wait_called()
assert mock_motors.stubs[stub].called
assert read_values == ids_values
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, ids_values = (10, 4, {1: 1337})
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
else:
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
assert read_comm == dxl.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"data_name, dxl_id, value",
"addr, length, ids_values",
[
("Torque_Enable", 1, 0),
("Torque_Enable", 1, 1),
("Goal_Position", 2, 1337),
("Goal_Position", 3, 42),
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test_write(data_name, dxl_id, value, mock_motors, dummy_motors):
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
motors_bus.write(data_name, f"dummy_{dxl_id}", value, normalize=False)
comm = bus._sync_write(addr, length, ids_values)
assert mock_motors.stubs[stub_name].called
assert mock_motors.stubs[stub].wait_called()
assert comm == dxl.COMM_SUCCESS
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
@ -319,18 +285,18 @@ def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
encoded_homings = {m.id: encode_twos_complement(m.homing_offset, 4) for m in dummy_calibration.values()}
mins = {m.id: m.range_min for m in dummy_calibration.values()}
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
drive_modes_stub = mock_motors.build_sync_read_stub("Drive_Mode", drive_modes)
offsets_stub = mock_motors.build_sync_read_stub("Homing_Offset", encoded_homings)
mins_stub = mock_motors.build_sync_read_stub("Min_Position_Limit", mins)
maxes_stub = mock_motors.build_sync_read_stub("Max_Position_Limit", maxes)
motors_bus = DynamixelMotorsBus(
drive_modes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Drive_Mode"], drive_modes)
offsets_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings)
mins_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins)
maxes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes)
bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
calibration=dummy_calibration,
)
motors_bus.connect(assert_motors_exist=False)
bus.connect(handshake=False)
is_calibrated = motors_bus.is_calibrated
is_calibrated = bus.is_calibrated
assert is_calibrated
assert mock_motors.stubs[drive_modes_stub].called
@ -344,17 +310,20 @@ def test_reset_calibration(mock_motors, dummy_motors):
write_mins_stubs = []
write_maxes_stubs = []
for motor in dummy_motors.values():
write_homing_stubs.append(mock_motors.build_write_stub("Homing_Offset", motor.id, 0))
write_mins_stubs.append(mock_motors.build_write_stub("Min_Position_Limit", motor.id, 0))
write_maxes_stubs.append(mock_motors.build_write_stub("Max_Position_Limit", motor.id, 4095))
write_homing_stubs.append(
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
)
write_mins_stubs.append(
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
)
write_maxes_stubs.append(
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
motors_bus.reset_calibration()
bus.reset_calibration()
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
@ -376,23 +345,22 @@ def test_set_half_turn_homings(mock_motors, dummy_motors):
2: 2005, # 2047 - 42
3: -1625, # 2047 - 3672
}
read_pos_stub = mock_motors.build_sync_read_stub("Present_Position", current_positions)
read_pos_stub = mock_motors.build_sync_read_stub(
*X_SERIES_CONTROL_TABLE["Present_Position"], current_positions
)
write_homing_stubs = []
for id_, homing in expected_homings.items():
encoded_homing = encode_twos_complement(homing, 4)
stub = mock_motors.build_write_stub("Homing_Offset", id_, encoded_homing)
stub = mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing)
write_homing_stubs.append(stub)
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
motors_bus.reset_calibration = MagicMock()
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
bus.reset_calibration = MagicMock()
motors_bus.set_half_turn_homings()
bus.set_half_turn_homings()
motors_bus.reset_calibration.assert_called_once()
bus.reset_calibration.assert_called_once()
assert mock_motors.stubs[read_pos_stub].called
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
@ -413,15 +381,14 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors):
"dummy_2": 3600,
"dummy_3": 4002,
}
read_pos_stub = mock_motors.build_sequential_sync_read_stub("Present_Position", positions)
read_pos_stub = mock_motors.build_sequential_sync_read_stub(
*X_SERIES_CONTROL_TABLE["Present_Position"], positions
)
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
motors_bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
mins, maxes = motors_bus.record_ranges_of_motion(display_values=False)
mins, maxes = bus.record_ranges_of_motion(display_values=False)
assert mock_motors.stubs[read_pos_stub].calls == 3
assert mins == expected_mins

View File

@ -1,3 +1,4 @@
import re
import sys
from typing import Generator
from unittest.mock import MagicMock, patch
@ -6,7 +7,8 @@ import pytest
import scservo_sdk as scs
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.common.motors.feetech import MODEL_NUMBER_TABLE, FeetechMotorsBus
from lerobot.common.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus
from lerobot.common.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE
from lerobot.common.utils.encoding_utils import encode_sign_magnitude
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
@ -61,48 +63,27 @@ def test_autouse_patch():
@pytest.mark.parametrize(
"value, n_bytes, expected",
"protocol, value, length, expected",
[
(0x12, 1, [0x12]),
(0x1234, 2, [0x34, 0x12]),
(0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
(0, 1, [0x00]),
(0, 2, [0x00, 0x00]),
(0, 4, [0x00, 0x00, 0x00, 0x00]),
(255, 1, [0xFF]),
(65535, 2, [0xFF, 0xFF]),
(4294967295, 4, [0xFF, 0xFF, 0xFF, 0xFF]),
(0, 0x12, 1, [0x12]),
(1, 0x12, 1, [0x12]),
(0, 0x1234, 2, [0x34, 0x12]),
(1, 0x1234, 2, [0x12, 0x34]),
(0, 0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
(1, 0x12345678, 4, [0x56, 0x78, 0x12, 0x34]),
],
ids=[
"1 byte",
"2 bytes",
"4 bytes",
"0 with 1 byte",
"0 with 2 bytes",
"0 with 4 bytes",
"max single byte",
"max two bytes",
"max four bytes",
"P0: 1 byte",
"P1: 1 byte",
"P0: 2 bytes",
"P1: 2 bytes",
"P0: 4 bytes",
"P1: 4 bytes",
],
) # fmt: skip
def test_split_int_to_bytes(value, n_bytes, expected):
assert FeetechMotorsBus._split_int_to_bytes(value, n_bytes) == expected
def test_split_int_to_bytes_invalid_n_bytes():
with pytest.raises(NotImplementedError):
FeetechMotorsBus._split_int_to_bytes(100, 3)
def test_split_int_to_bytes_negative_numbers():
with pytest.raises(ValueError):
neg = FeetechMotorsBus._split_int_to_bytes(-1, 1)
print(neg)
def test_split_int_to_bytes_large_number():
with pytest.raises(ValueError):
FeetechMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
def test__split_into_byte_chunks(protocol, value, length, expected):
bus = FeetechMotorsBus("", {}, protocol_version=protocol)
assert bus._split_into_byte_chunks(value, length) == expected
def test_abc_implementation(dummy_motors):
@ -110,35 +91,19 @@ def test_abc_implementation(dummy_motors):
FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
@pytest.mark.skip("TODO")
def test_scan_port(mock_motors):
expected = {
9_600: {1: 777},
57_600: {2: 777},
500_000: {237: 777},
}
expected_model_nbs = {id_: model for d in expected.values() for id_, model in d.items()}
ping_stub = mock_motors.build_broadcast_ping_stub(list(expected_model_nbs))
mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs)
found = FeetechMotorsBus.scan_port(mock_motors.port)
assert found == expected
assert mock_motors.stubs[ping_stub].called
assert mock_motors.stubs[mobel_nb_stub].called
@pytest.mark.parametrize("id_", [1, 2, 3])
def test_ping(id_, mock_motors, dummy_motors):
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
addr, length = MODEL_NUMBER
ping_stub = mock_motors.build_ping_stub(id_)
mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb)
motors_bus = FeetechMotorsBus(
mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
bus.connect(handshake=False)
ping_model_nb = motors_bus.ping(id_)
ping_model_nb = bus.ping(id_)
assert ping_model_nb == expected_model_nb
assert mock_motors.stubs[ping_stub].called
@ -147,208 +112,221 @@ def test_ping(id_, mock_motors, dummy_motors):
def test_broadcast_ping(mock_motors, dummy_motors):
models = {m.id: m.model for m in dummy_motors.values()}
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
addr, length = MODEL_NUMBER
ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs)
motors_bus = FeetechMotorsBus(
mobel_nb_stubs = []
expected_model_nbs = {}
for id_, model in models.items():
model_nb = MODEL_NUMBER_TABLE[model]
stub = mock_motors.build_read_stub(addr, length, id_, model_nb)
expected_model_nbs[id_] = model_nb
mobel_nb_stubs.append(stub)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
bus.connect(handshake=False)
ping_model_nbs = motors_bus.broadcast_ping()
ping_model_nbs = bus.broadcast_ping()
assert ping_model_nbs == expected_model_nbs
assert mock_motors.stubs[ping_stub].called
assert mock_motors.stubs[mobel_nb_stub].called
def test_sync_read_none(mock_motors, dummy_motors):
expected_positions = {
"dummy_1": 1337,
"dummy_2": 42,
"dummy_3": 4016,
}
ids_values = dict(zip([1, 2, 3], expected_positions.values(), strict=True))
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
read_positions = motors_bus.sync_read("Present_Position", normalize=False)
assert mock_motors.stubs[stub_name].called
assert read_positions == expected_positions
assert all(mock_motors.stubs[stub].called for stub in mobel_nb_stubs)
@pytest.mark.parametrize(
"id_, position",
"addr, length, id_, value",
[
(1, 1337),
(2, 42),
(3, 4016),
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test_sync_read_single_value(id_, position, mock_motors, dummy_motors):
expected_position = {f"dummy_{id_}": position}
stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position})
motors_bus = FeetechMotorsBus(
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
stub = mock_motors.build_read_stub(addr, length, id_, value)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
bus.connect(handshake=False)
read_position = motors_bus.sync_read("Present_Position", f"dummy_{id_}", normalize=False)
read_value, _, _ = bus._read(addr, length, id_)
assert mock_motors.stubs[stub_name].called
assert read_position == expected_position
assert mock_motors.stubs[stub].called
assert read_value == value
@pytest.mark.parametrize(
"ids, positions",
[
([1], [1337]),
([1, 2], [1337, 42]),
([1, 2, 3], [1337, 42, 4016]),
],
ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip
def test_sync_read(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions)
names = [f"dummy_{dxl_id}" for dxl_id in ids]
expected_positions = dict(zip(names, positions, strict=True))
ids_values = dict(zip(ids, positions, strict=True))
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
motors_bus = FeetechMotorsBus(
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
bus.connect(handshake=False)
read_positions = motors_bus.sync_read("Present_Position", names, normalize=False)
assert mock_motors.stubs[stub_name].called
assert read_positions == expected_positions
@pytest.mark.parametrize(
"num_retry, num_invalid_try, pos",
[
(0, 2, 1337),
(2, 3, 42),
(3, 2, 4016),
(2, 1, 999),
],
)
def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors):
expected_position = {"dummy_1": pos}
stub_name = mock_motors.build_sync_read_stub(
"Present_Position", {1: pos}, num_invalid_try=num_invalid_try
)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
if num_retry >= num_invalid_try:
pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
assert pos_dict == expected_position
if raise_on_error:
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
with pytest.raises(ConnectionError):
_ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_error == error
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
assert mock_motors.stubs[stub_name].calls == expected_calls
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"data_name, value",
[
("Torque_Enable", 0),
("Torque_Enable", 1),
("Goal_Position", 1337),
("Goal_Position", 42),
],
)
def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors):
ids_values = {m.id: value for m in dummy_motors.values()}
stub_name = mock_motors.build_sync_write_stub(data_name, ids_values)
motors_bus = FeetechMotorsBus(
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
bus.connect(handshake=False)
motors_bus.sync_write(data_name, value, normalize=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_comm == scs.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub_name].wait_called()
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"ids, positions",
"addr, length, id_, value",
[
([1], [1337]),
([1, 2], [1337, 42]),
([1, 2, 3], [1337, 42, 4016]),
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
stub = mock_motors.build_write_stub(addr, length, id_, value)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
bus.connect(handshake=False)
comm, error = bus._write(addr, length, id_, value)
assert mock_motors.stubs[stub].called
assert comm == scs.COMM_SUCCESS
assert error == 0
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_error == error
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_comm == scs.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"addr, length, ids_values",
[
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
) # fmt: skip
def test_sync_write(ids, positions, mock_motors, dummy_motors):
assert len(ids) == len(positions)
ids_values = dict(zip(ids, positions, strict=True))
stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
)
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
write_values = {f"dummy_{id_}": pos for id_, pos in ids_values.items()}
motors_bus.sync_write("Goal_Position", write_values, normalize=False)
read_values, _ = bus._sync_read(addr, length, list(ids_values))
assert mock_motors.stubs[stub_name].wait_called()
assert mock_motors.stubs[stub].called
assert read_values == ids_values
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, ids_values = (10, 4, {1: 1337})
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
else:
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
assert read_comm == scs.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"data_name, dxl_id, value",
"addr, length, ids_values",
[
("Torque_Enable", 1, 0),
("Torque_Enable", 1, 1),
("Goal_Position", 2, 1337),
("Goal_Position", 3, 42),
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test_write(data_name, dxl_id, value, mock_motors, dummy_motors):
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
motors_bus.write(data_name, f"dummy_{dxl_id}", value, normalize=False)
comm = bus._sync_write(addr, length, ids_values)
assert mock_motors.stubs[stub_name].called
assert mock_motors.stubs[stub].wait_called()
assert comm == scs.COMM_SUCCESS
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
encoded_homings = {m.id: encode_sign_magnitude(m.homing_offset, 11) for m in dummy_calibration.values()}
mins = {m.id: m.range_min for m in dummy_calibration.values()}
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
offsets_stub = mock_motors.build_sync_read_stub("Homing_Offset", encoded_homings)
mins_stub = mock_motors.build_sync_read_stub("Min_Position_Limit", mins)
maxes_stub = mock_motors.build_sync_read_stub("Max_Position_Limit", maxes)
motors_bus = FeetechMotorsBus(
offsets_stub = mock_motors.build_sync_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings
)
mins_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins)
maxes_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
calibration=dummy_calibration,
)
motors_bus.connect(assert_motors_exist=False)
bus.connect(handshake=False)
is_calibrated = motors_bus.is_calibrated
is_calibrated = bus.is_calibrated
assert is_calibrated
assert mock_motors.stubs[offsets_stub].called
@ -361,17 +339,20 @@ def test_reset_calibration(mock_motors, dummy_motors):
write_mins_stubs = []
write_maxes_stubs = []
for motor in dummy_motors.values():
write_homing_stubs.append(mock_motors.build_write_stub("Homing_Offset", motor.id, 0))
write_mins_stubs.append(mock_motors.build_write_stub("Min_Position_Limit", motor.id, 0))
write_maxes_stubs.append(mock_motors.build_write_stub("Max_Position_Limit", motor.id, 4095))
write_homing_stubs.append(
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
)
write_mins_stubs.append(
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
)
write_maxes_stubs.append(
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
motors_bus.reset_calibration()
bus.reset_calibration()
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
@ -393,23 +374,24 @@ def test_set_half_turn_homings(mock_motors, dummy_motors):
2: -2005, # 42 - 2047
3: 1625, # 3672 - 2047
}
read_pos_stub = mock_motors.build_sync_read_stub("Present_Position", current_positions)
read_pos_stub = mock_motors.build_sync_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions
)
write_homing_stubs = []
for id_, homing in expected_homings.items():
encoded_homing = encode_sign_magnitude(homing, 11)
stub = mock_motors.build_write_stub("Homing_Offset", id_, encoded_homing)
stub = mock_motors.build_write_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing
)
write_homing_stubs.append(stub)
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
motors_bus.reset_calibration = MagicMock()
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
bus.reset_calibration = MagicMock()
motors_bus.set_half_turn_homings()
bus.set_half_turn_homings()
motors_bus.reset_calibration.assert_called_once()
bus.reset_calibration.assert_called_once()
assert mock_motors.stubs[read_pos_stub].called
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
@ -430,16 +412,15 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors):
"dummy_2": 3600,
"dummy_3": 4002,
}
read_pos_stub = mock_motors.build_sequential_sync_read_stub("Present_Position", positions)
stub = mock_motors.build_sequential_sync_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions
)
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
motors_bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
motors_bus.connect(assert_motors_exist=False)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
mins, maxes = motors_bus.record_ranges_of_motion(display_values=False)
mins, maxes = bus.record_ranges_of_motion(display_values=False)
assert mock_motors.stubs[read_pos_stub].calls == 3
assert mock_motors.stubs[stub].calls == 3
assert mins == expected_mins
assert maxes == expected_maxes

View File

@ -1,87 +1,469 @@
# ruff: noqa: N802
import re
from unittest.mock import patch
import pytest
from lerobot.common.motors.motors_bus import assert_same_address, get_address, get_ctrl_table
from lerobot.common.motors.motors_bus import (
Motor,
MotorNormMode,
MotorsBus,
assert_same_address,
get_address,
get_ctrl_table,
)
# TODO(aliberts)
# class DummyMotorsBus(MotorsBus):
# def __init__(self, port: str, motors: dict[str, Motor]):
# super().__init__(port, motors)
DUMMY_CTRL_TABLE_1 = {
"Firmware_Version": (0, 1),
"Model_Number": (1, 2),
"Present_Position": (3, 4),
"Goal_Position": (11, 2),
}
DUMMY_CTRL_TABLE_2 = {
"Model_Number": (0, 2),
"Firmware_Version": (2, 1),
"Present_Position": (3, 4),
"Present_Velocity": (7, 4),
"Goal_Position": (11, 4),
"Goal_Velocity": (15, 4),
"Lock": (19, 1),
}
DUMMY_MODEL_CTRL_TABLE = {
"model_1": DUMMY_CTRL_TABLE_1,
"model_2": DUMMY_CTRL_TABLE_2,
"model_3": DUMMY_CTRL_TABLE_2,
}
DUMMY_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
}
DUMMY_MODEL_BAUDRATE_TABLE = {
"model_1": DUMMY_BAUDRATE_TABLE,
"model_2": DUMMY_BAUDRATE_TABLE,
"model_3": DUMMY_BAUDRATE_TABLE,
}
DUMMY_ENCODING_TABLE = {
"Present_Position": 8,
"Goal_Position": 10,
}
DUMMY_MODEL_ENCODING_TABLE = {
"model_1": DUMMY_ENCODING_TABLE,
"model_2": DUMMY_ENCODING_TABLE,
"model_3": DUMMY_ENCODING_TABLE,
}
DUMMY_MODEL_NUMBER_TABLE = {
"model_1": 1234,
"model_2": 5678,
"model_3": 5799,
}
DUMMY_MODEL_RESOLUTION_TABLE = {
"model_1": 4096,
"model_2": 1024,
"model_3": 4096,
}
class MockPortHandler:
def __init__(self, port_name):
self.is_open: bool = False
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool = False
self.port_name: str = port_name
self.ser = None
def openPort(self):
self.is_open = True
return self.is_open
def closePort(self):
self.is_open = False
def clearPort(self): ...
def setPortName(self, port_name):
self.port_name = port_name
def getPortName(self):
return self.port_name
def setBaudRate(self, baudrate):
self.baudrate: baudrate
def getBaudRate(self):
return self.baudrate
def getBytesAvailable(self): ...
def readPort(self, length): ...
def writePort(self, packet): ...
def setPacketTimeout(self, packet_length): ...
def setPacketTimeoutMillis(self, msec): ...
def isPacketTimeout(self): ...
def getCurrentTime(self): ...
def getTimeSinceStart(self): ...
def setupPort(self, cflag_baud): ...
def getCFlagBaud(self, baudrate): ...
class MockMotorsBus(MotorsBus):
available_baudrates = [500_000, 1_000_000]
default_timeout = 1000
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
model_number_table = DUMMY_MODEL_NUMBER_TABLE
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
normalized_data = ["Present_Position", "Goal_Position"]
def __init__(self, port: str, motors: dict[str, Motor]):
super().__init__(port, motors)
self.port_handler = MockPortHandler(port)
def _assert_protocol_is_compatible(self, instruction_name): ...
def _handshake(self): ...
def _find_single_motor(self, motor, initial_baudrate): ...
def configure_motors(self): ...
def read_calibration(self): ...
def write_calibration(self, calibration_dict): ...
def disable_torque(self, motors): ...
def enable_torque(self, motors): ...
def _get_half_turn_homings(self, positions): ...
def _encode_sign(self, data_name, ids_values): ...
def _decode_sign(self, data_name, ids_values): ...
def _split_into_byte_chunks(self, value, length): ...
def broadcast_ping(self, num_retry, raise_on_error): ...
@pytest.fixture
def ctrl_table_1() -> dict:
def dummy_motors() -> dict[str, Motor]:
return {
"Firmware_Version": (0, 1),
"Model_Number": (1, 2),
"Present_Position": (3, 4),
"Goal_Position": (7, 2),
"dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100),
"dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100),
"dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100),
}
@pytest.fixture
def ctrl_table_2() -> dict:
return {
"Model_Number": (0, 2),
"Firmware_Version": (2, 1),
"Present_Position": (3, 4),
"Goal_Position": (7, 4),
"Lock": (7, 4),
}
@pytest.fixture
def model_ctrl_table(ctrl_table_1, ctrl_table_2) -> dict:
return {
"model_1": ctrl_table_1,
"model_2": ctrl_table_2,
}
def test_get_ctrl_table(model_ctrl_table, ctrl_table_1):
def test_get_ctrl_table():
model = "model_1"
ctrl_table = get_ctrl_table(model_ctrl_table, model)
assert ctrl_table == ctrl_table_1
ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
assert ctrl_table == DUMMY_CTRL_TABLE_1
def test_get_ctrl_table_error(model_ctrl_table):
def test_get_ctrl_table_error():
model = "model_99"
with pytest.raises(KeyError, match=f"Control table for {model=} not found."):
get_ctrl_table(model_ctrl_table, model)
get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
def test_get_address(model_ctrl_table):
addr, n_bytes = get_address(model_ctrl_table, "model_1", "Firmware_Version")
def test_get_address():
addr, n_bytes = get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", "Firmware_Version")
assert addr == 0
assert n_bytes == 1
def test_get_address_error(model_ctrl_table):
def test_get_address_error():
model = "model_1"
data_name = "Lock"
with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."):
get_address(model_ctrl_table, "model_1", data_name)
get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", data_name)
def test_assert_same_address(model_ctrl_table):
def test_assert_same_address():
models = ["model_1", "model_2"]
assert_same_address(model_ctrl_table, models, "Present_Position")
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position")
def test_assert_same_address_different_addresses(model_ctrl_table):
def test_assert_same_length_different_addresses():
models = ["model_1", "model_2"]
with pytest.raises(
NotImplementedError,
match=re.escape("At least two motor models use a different address"),
):
assert_same_address(model_ctrl_table, models, "Model_Number")
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number")
def test_assert_same_address_different_bytes(model_ctrl_table):
def test_assert_same_address_different_length():
models = ["model_1", "model_2"]
with pytest.raises(
NotImplementedError,
match=re.escape("At least two motor models use a different bytes representation"),
):
assert_same_address(model_ctrl_table, models, "Goal_Position")
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Goal_Position")
def test__serialize_data_invalid_length():
bus = MockMotorsBus("", {})
with pytest.raises(NotImplementedError):
bus._serialize_data(100, 3)
def test__serialize_data_negative_numbers():
bus = MockMotorsBus("", {})
with pytest.raises(ValueError):
bus._serialize_data(-1, 1)
def test__serialize_data_large_number():
bus = MockMotorsBus("", {})
with pytest.raises(ValueError):
bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF
@pytest.mark.parametrize(
"data_name, id_, value",
[
("Firmware_Version", 1, 14),
("Model_Number", 1, 5678),
("Present_Position", 2, 1337),
("Present_Velocity", 3, 42),
],
)
def test_read(data_name, id_, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
with (
patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read,
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
):
returned_value = bus.read(data_name, f"dummy_{id_}")
assert returned_value == value
mock__read.assert_called_once_with(
addr,
length,
id_,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(data_name, {id_: value})
@pytest.mark.parametrize(
"data_name, id_, value",
[
("Goal_Position", 1, 1337),
("Goal_Velocity", 2, 3682),
("Lock", 3, 1),
],
)
def test_write(data_name, id_, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
with (
patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write,
patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign,
patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize,
):
bus.write(data_name, f"dummy_{id_}", value)
mock__write.assert_called_once_with(
addr,
length,
id_,
value,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.",
)
mock__encode_sign.assert_called_once_with(data_name, {id_: value})
if data_name in bus.normalized_data:
mock__unnormalize.assert_called_once_with(data_name, {id_: value})
@pytest.mark.parametrize(
"data_name, id_, value",
[
("Firmware_Version", 1, 14),
("Model_Number", 1, 5678),
("Present_Position", 2, 1337),
("Present_Velocity", 3, 42),
],
)
def test_sync_read_by_str(data_name, id_, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids = [id_]
expected_value = {f"dummy_{id_}": value}
with (
patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read,
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
):
returned_dict = bus.sync_read(data_name, f"dummy_{id_}")
assert returned_dict == expected_value
mock__sync_read.assert_called_once_with(
addr,
length,
ids,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(data_name, {id_: value})
@pytest.mark.parametrize(
"data_name, ids_values",
[
("Model_Number", {1: 5678}),
("Present_Position", {1: 1337, 2: 42}),
("Present_Velocity", {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test_sync_read_by_list(data_name, ids_values, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids = list(ids_values)
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
with (
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
):
returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids])
assert returned_dict == expected_values
mock__sync_read.assert_called_once_with(
addr,
length,
ids,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(data_name, ids_values)
@pytest.mark.parametrize(
"data_name, ids_values",
[
("Model_Number", {1: 5678, 2: 5799, 3: 5678}),
("Present_Position", {1: 1337, 2: 42, 3: 4016}),
("Goal_Position", {1: 4008, 2: 199, 3: 3446}),
],
ids=["Model_Number", "Present_Position", "Goal_Position"],
)
def test_sync_read_by_none(data_name, ids_values, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids = list(ids_values)
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
with (
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
):
returned_dict = bus.sync_read(data_name)
assert returned_dict == expected_values
mock__sync_read.assert_called_once_with(
addr,
length,
ids,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(data_name, ids_values)
@pytest.mark.parametrize(
"data_name, value",
[
("Goal_Position", 500),
("Goal_Velocity", 4010),
("Lock", 0),
],
)
def test_sync_write_by_single_value(data_name, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids_values = {m.id: value for m in dummy_motors.values()}
with (
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
):
bus.sync_write(data_name, value)
mock__sync_write.assert_called_once_with(
addr,
length,
ids_values,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
)
mock__encode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__unnormalize.assert_called_once_with(data_name, ids_values)
@pytest.mark.parametrize(
"data_name, ids_values",
[
("Goal_Position", {1: 1337, 2: 42, 3: 4016}),
("Goal_Velocity", {1: 50, 2: 83, 3: 2777}),
("Lock", {1: 0, 2: 0, 3: 1}),
],
ids=["Goal_Position", "Goal_Velocity", "Lock"],
)
def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
with (
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
):
bus.sync_write(data_name, values)
mock__sync_write.assert_called_once_with(
addr,
length,
ids_values,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
)
mock__encode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__unnormalize.assert_called_once_with(data_name, ids_values)

View File

@ -172,8 +172,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
push_to_hub=False,
# TODO(rcadene, aliberts): test video=True
video=False,
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
display_cameras=False,
display_data=False,
play_sounds=False,
)
dataset = record(robot, rec_cfg)
@ -226,7 +225,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
)
@ -273,7 +272,7 @@ def test_resume_record(tmp_path, request, robot_type, mock):
episode_time_s=1,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
num_episodes=1,
)
@ -330,7 +329,7 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
)
dataset = record(robot, rec_cfg)
@ -380,7 +379,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
)
@ -433,7 +432,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
display_data=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
)