Compare commits
42 Commits
c6548caf5d
...
a0657ee274
Author | SHA1 | Date |
---|---|---|
|
a0657ee274 | |
|
d7b9866a7c | |
|
6cd06196c3 | |
|
4f5d840cac | |
|
7dedbeb457 | |
|
6b4931b4f0 | |
|
a38e989cab | |
|
833ab383dd | |
|
48b7e2a137 | |
|
2b100122f5 | |
|
66325b5a42 | |
|
dc3360c06b | |
|
66017f16a0 | |
|
87b0a5995c | |
|
cf35a5e986 | |
|
caa69be553 | |
|
a247e4b2be | |
|
5c925c779b | |
|
73956e31b2 | |
|
d43f1a8136 | |
|
bf1c737858 | |
|
d07c7347f8 | |
|
57e5e4cc07 | |
|
2743c29a96 | |
|
2bb73ac431 | |
|
9afc4b771c | |
|
f71e224023 | |
|
889de7c415 | |
|
3539251b18 | |
|
1f210bc8a3 | |
|
d70bc4bde9 | |
|
bdbca09cb2 | |
|
e0b292ab51 | |
|
f960f4d8d4 | |
|
9e57ec7837 | |
|
0a7f51f0da | |
|
4ca92a28e9 | |
|
0464dc91b3 | |
|
d32daebf75 | |
|
5322417c03 | |
|
4041f57943 | |
|
2c86fea78a |
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
.dev
|
||||||
# Logging
|
# Logging
|
||||||
logs
|
logs
|
||||||
tmp
|
tmp
|
||||||
|
|
|
@ -36,8 +36,8 @@ repos:
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
|
|
||||||
- repo: https://github.com/crate-ci/typos
|
- repo: https://github.com/adhtruong/mirrors-typos
|
||||||
rev: v1
|
rev: v1.31.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: typos
|
- id: typos
|
||||||
args: [--force-exclude]
|
args: [--force-exclude]
|
||||||
|
|
|
@ -98,14 +98,14 @@ conda create -y -n lerobot python=3.10
|
||||||
conda activate lerobot
|
conda activate lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
When using `miniconda`, if you don't have `ffmpeg` in your environment:
|
When using `miniconda`, install `ffmpeg` in your environment:
|
||||||
```bash
|
```bash
|
||||||
conda install ffmpeg
|
conda install ffmpeg -c conda-forge
|
||||||
```
|
```
|
||||||
|
|
||||||
Install 🤗 LeRobot:
|
Install 🤗 LeRobot:
|
||||||
```bash
|
```bash
|
||||||
pip install --no-binary=av -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
|
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
|
||||||
|
@ -118,7 +118,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
|
||||||
|
|
||||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||||
```bash
|
```bash
|
||||||
pip install --no-binary=av -e ".[aloha, pusht]"
|
pip install -e ".[aloha, pusht]"
|
||||||
```
|
```
|
||||||
|
|
||||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||||
|
|
|
@ -17,12 +17,21 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
import os
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import rerun as rr
|
||||||
|
|
||||||
|
# see https://rerun.io/docs/howto/visualization/limit-ram
|
||||||
|
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
|
||||||
|
|
||||||
|
|
||||||
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int):
|
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
|
||||||
|
rr.init("lerobot_capture_camera_feed")
|
||||||
|
rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)
|
||||||
|
|
||||||
now = dt.datetime.now()
|
now = dt.datetime.now()
|
||||||
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
|
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
|
||||||
if not capture_dir.exists():
|
if not capture_dir.exists():
|
||||||
|
@ -39,24 +48,21 @@ def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height
|
||||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
|
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
|
||||||
|
|
||||||
frame_index = 0
|
frame_index = 0
|
||||||
while True:
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < duration:
|
||||||
ret, frame = cap.read()
|
ret, frame = cap.read()
|
||||||
|
|
||||||
if not ret:
|
if not ret:
|
||||||
print("Error: Could not read frame.")
|
print("Error: Could not read frame.")
|
||||||
break
|
break
|
||||||
|
rr.log("video/stream", rr.Image(frame.numpy()), static=True)
|
||||||
cv2.imshow("Video Stream", frame)
|
|
||||||
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
|
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
|
||||||
frame_index += 1
|
frame_index += 1
|
||||||
|
|
||||||
# Break the loop on 'q' key press
|
# Release the capture
|
||||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
|
||||||
break
|
|
||||||
|
|
||||||
# Release the capture and destroy all windows
|
|
||||||
cap.release()
|
cap.release()
|
||||||
cv2.destroyAllWindows()
|
|
||||||
|
# TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -86,5 +92,11 @@ if __name__ == "__main__":
|
||||||
default=720,
|
default=720,
|
||||||
help="Height of the captured images.",
|
help="Height of the captured images.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--duration",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="Duration in seconds for which the video stream should be captured.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
display_and_save_video_stream(**vars(args))
|
display_and_save_video_stream(**vars(args))
|
||||||
|
|
|
@ -18,7 +18,7 @@ training outputs directory. In the latter case, you might want to run examples/3
|
||||||
|
|
||||||
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
|
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
|
||||||
```bash
|
```bash
|
||||||
pip install --no-binary=av -e ".[pusht]"`
|
pip install -e ".[pusht]"
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ First, install the additional dependencies required for robots built with dynami
|
||||||
|
|
||||||
Using `pip`:
|
Using `pip`:
|
||||||
```bash
|
```bash
|
||||||
pip install --no-binary=av -e ".[dynamixel]"
|
pip install -e ".[dynamixel]"
|
||||||
```
|
```
|
||||||
|
|
||||||
Using `poetry`:
|
Using `poetry`:
|
||||||
|
@ -55,6 +55,9 @@ Finally, connect both arms to your computer via USB. Note that the USB doesn't p
|
||||||
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
|
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
|
||||||
|
|
||||||
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||||
|
|
||||||
|
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py \
|
python lerobot/scripts/control_robot.py \
|
||||||
--robot.type=koch \
|
--robot.type=koch \
|
||||||
|
@ -829,7 +832,7 @@ It contains:
|
||||||
Troubleshooting:
|
Troubleshooting:
|
||||||
- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
|
- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
|
||||||
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
|
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
|
||||||
- or, install [Homebrew](https://brew.sh) and run `brew install ffmpeg` (it should be compiled with `libsvtav1`),
|
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform (check the version installed with `ffmpeg -encoders | grep libsvtav1`). If it isn't `ffmpeg 7.X` or lacks `libsvtav1` support, you can explicitly install `ffmpeg 7.X` using: `conda install ffmpeg=7.1.1 -c conda-forge`
|
||||||
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
|
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
|
||||||
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig, RobotMode
|
||||||
|
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||||
|
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||||
|
from lerobot.common.teleoperators.so100 import SO100Leader, SO100LeaderConfig
|
||||||
|
|
||||||
|
# TODO(Steven): Check validity of these features
|
||||||
|
DUMMY_FEATURES = {
|
||||||
|
"observation.state": {
|
||||||
|
"dtype": "float64",
|
||||||
|
"shape": (9,),
|
||||||
|
"names": {
|
||||||
|
"motors": [
|
||||||
|
"arm_shoulder_pan",
|
||||||
|
"arm_shoulder_lift",
|
||||||
|
"arm_elbow_flex",
|
||||||
|
"arm_wrist_flex",
|
||||||
|
"arm_wrist_roll",
|
||||||
|
"arm_gripper",
|
||||||
|
"base_left_wheel",
|
||||||
|
"base_right_wheel",
|
||||||
|
"base_back_wheel",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"dtype": "float64",
|
||||||
|
"shape": (9,),
|
||||||
|
"names": {
|
||||||
|
"motors": [
|
||||||
|
"arm_shoulder_pan",
|
||||||
|
"arm_shoulder_lift",
|
||||||
|
"arm_elbow_flex",
|
||||||
|
"arm_wrist_flex",
|
||||||
|
"arm_wrist_roll",
|
||||||
|
"arm_gripper",
|
||||||
|
"base_left_wheel",
|
||||||
|
"base_right_wheel",
|
||||||
|
"base_back_wheel",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"observation.images.front": {
|
||||||
|
"dtype": "image",
|
||||||
|
"shape": (640, 480, 3),
|
||||||
|
"names": [
|
||||||
|
"width",
|
||||||
|
"height",
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"observation.images.wrist": {
|
||||||
|
"dtype": "image",
|
||||||
|
"shape": (480, 640, 3),
|
||||||
|
"names": [
|
||||||
|
"width",
|
||||||
|
"height",
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
logging.info("Configuring Teleop Devices")
|
||||||
|
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem58760434171")
|
||||||
|
leader_arm = SO100Leader(leader_arm_config)
|
||||||
|
|
||||||
|
keyboard_config = KeyboardTeleopConfig()
|
||||||
|
keyboard = KeyboardTeleop(keyboard_config)
|
||||||
|
|
||||||
|
logging.info("Configuring LeKiwi Client")
|
||||||
|
robot_config = LeKiwiClientConfig(id="lekiwi", robot_mode=RobotMode.TELEOP)
|
||||||
|
robot = LeKiwiClient(robot_config)
|
||||||
|
|
||||||
|
logging.info("Creating LeRobot Dataset")
|
||||||
|
|
||||||
|
# # TODO(Steven): Check this creation
|
||||||
|
# dataset = LeRobotDataset.create(
|
||||||
|
# repo_id="user/lekiwi2",
|
||||||
|
# fps=10,
|
||||||
|
# features=DUMMY_FEATURES,
|
||||||
|
# )
|
||||||
|
|
||||||
|
logging.info("Connecting Teleop Devices")
|
||||||
|
leader_arm.connect()
|
||||||
|
keyboard.connect()
|
||||||
|
|
||||||
|
logging.info("Connecting remote LeKiwi")
|
||||||
|
robot.connect()
|
||||||
|
|
||||||
|
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||||
|
logging.error("Failed to connect to all devices")
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.info("Starting LeKiwi teleoperation")
|
||||||
|
i = 0
|
||||||
|
while i < 1000:
|
||||||
|
arm_action = leader_arm.get_action()
|
||||||
|
base_action = keyboard.get_action()
|
||||||
|
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||||
|
|
||||||
|
# TODO(Steven): Deal with policy action space
|
||||||
|
# robot.set_mode(RobotMode.AUTO)
|
||||||
|
# policy_action = policy.get_action() # This might be in body frame, key space or smt else
|
||||||
|
# robot.send_action(policy_action)
|
||||||
|
|
||||||
|
action_sent = robot.send_action(action)
|
||||||
|
observation = robot.get_observation()
|
||||||
|
|
||||||
|
frame = {**action_sent, **observation}
|
||||||
|
frame.update({"task": "Dummy Task Dataset"})
|
||||||
|
|
||||||
|
logging.info("Saved a frame into the dataset")
|
||||||
|
# dataset.add_frame(frame)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# dataset.save_episode()
|
||||||
|
# dataset.push_to_hub()
|
||||||
|
|
||||||
|
logging.info("Disconnecting Teleop Devices and LeKiwi Client")
|
||||||
|
robot.disconnect()
|
||||||
|
leader_arm.disconnect()
|
||||||
|
keyboard.disconnect()
|
||||||
|
logging.info("Finished LeKiwi cleanly")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -21,7 +21,7 @@ def main():
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
keyboard.disconnect()
|
keyboard.disconnect()
|
||||||
logging.info("Finished LeKiwiRobot cleanly")
|
logging.info("Finished LeKiwi cleanly")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
|
@ -48,5 +48,5 @@ default_cache_path = Path(HF_HOME) / "lerobot"
|
||||||
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||||
|
|
||||||
# calibration dir
|
# calibration dir
|
||||||
default_calibration_path = HF_LEROBOT_HOME / ".calibration"
|
default_calibration_path = HF_LEROBOT_HOME / "calibration"
|
||||||
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
||||||
|
|
|
@ -15,3 +15,14 @@ class DeviceAlreadyConnectedError(ConnectionError):
|
||||||
):
|
):
|
||||||
self.message = message
|
self.message = message
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidActionError(ConnectionError):
|
||||||
|
"""Exception raised when an action is already invalid."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message="The action is invalid. Check the value follows what it is expected from the action space.",
|
||||||
|
):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
|
@ -35,7 +35,7 @@ from .tables import (
|
||||||
)
|
)
|
||||||
|
|
||||||
PROTOCOL_VERSION = 2.0
|
PROTOCOL_VERSION = 2.0
|
||||||
BAUDRATE = 1_000_000
|
DEFAULT_BAUDRATE = 1_000_000
|
||||||
DEFAULT_TIMEOUT_MS = 1000
|
DEFAULT_TIMEOUT_MS = 1000
|
||||||
|
|
||||||
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
|
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
|
||||||
|
@ -84,6 +84,23 @@ class TorqueMode(Enum):
|
||||||
DISABLED = 0
|
DISABLED = 0
|
||||||
|
|
||||||
|
|
||||||
|
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||||
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
if length == 1:
|
||||||
|
data = [value]
|
||||||
|
elif length == 2:
|
||||||
|
data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
|
||||||
|
elif length == 4:
|
||||||
|
data = [
|
||||||
|
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||||
|
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||||
|
dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
|
||||||
|
dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
|
||||||
|
]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class DynamixelMotorsBus(MotorsBus):
|
class DynamixelMotorsBus(MotorsBus):
|
||||||
"""
|
"""
|
||||||
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
||||||
|
@ -92,6 +109,7 @@ class DynamixelMotorsBus(MotorsBus):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
||||||
|
default_baudrate = DEFAULT_BAUDRATE
|
||||||
default_timeout = DEFAULT_TIMEOUT_MS
|
default_timeout = DEFAULT_TIMEOUT_MS
|
||||||
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
|
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
|
||||||
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||||
|
@ -119,19 +137,70 @@ class DynamixelMotorsBus(MotorsBus):
|
||||||
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
|
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _handshake(self) -> None:
|
||||||
|
self._assert_motors_exist()
|
||||||
|
|
||||||
|
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
|
||||||
|
model = self.motors[motor].model
|
||||||
|
search_baudrates = (
|
||||||
|
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
|
||||||
|
)
|
||||||
|
|
||||||
|
for baudrate in search_baudrates:
|
||||||
|
self.set_baudrate(baudrate)
|
||||||
|
id_model = self.broadcast_ping()
|
||||||
|
if id_model:
|
||||||
|
found_id, found_model = next(iter(id_model.items()))
|
||||||
|
expected_model_nb = self.model_number_table[model]
|
||||||
|
if found_model != expected_model_nb:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Found one motor on {baudrate=} with id={found_id} but it has a "
|
||||||
|
f"model number '{found_model}' different than the one expected: '{expected_model_nb}' "
|
||||||
|
f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')."
|
||||||
|
)
|
||||||
|
return baudrate, found_id
|
||||||
|
|
||||||
|
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
|
||||||
|
|
||||||
def configure_motors(self) -> None:
|
def configure_motors(self) -> None:
|
||||||
# By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on
|
# By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on
|
||||||
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
|
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
|
||||||
for id_ in self.ids:
|
for motor in self.motors:
|
||||||
self.write("Return_Delay_Time", id_, 0)
|
self.write("Return_Delay_Time", motor, 0)
|
||||||
|
|
||||||
def disable_torque(self, motors: str | list[str] | None = None) -> None:
|
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||||
for name in self._get_names_list(motors):
|
offsets = self.sync_read("Homing_Offset", normalize=False)
|
||||||
self.write("Torque_Enable", name, TorqueMode.DISABLED.value)
|
mins = self.sync_read("Min_Position_Limit", normalize=False)
|
||||||
|
maxes = self.sync_read("Max_Position_Limit", normalize=False)
|
||||||
|
drive_modes = self.sync_read("Drive_Mode", normalize=False)
|
||||||
|
|
||||||
def enable_torque(self, motors: str | list[str] | None = None) -> None:
|
calibration = {}
|
||||||
for name in self._get_names_list(motors):
|
for name, motor in self.motors.items():
|
||||||
self.write("Torque_Enable", name, TorqueMode.ENABLED.value)
|
calibration[name] = MotorCalibration(
|
||||||
|
id=motor.id,
|
||||||
|
drive_mode=drive_modes[name],
|
||||||
|
homing_offset=offsets[name],
|
||||||
|
range_min=mins[name],
|
||||||
|
range_max=maxes[name],
|
||||||
|
)
|
||||||
|
|
||||||
|
return calibration
|
||||||
|
|
||||||
|
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
|
||||||
|
for motor, calibration in calibration_dict.items():
|
||||||
|
self.write("Homing_Offset", motor, calibration.homing_offset)
|
||||||
|
self.write("Min_Position_Limit", motor, calibration.range_min)
|
||||||
|
self.write("Max_Position_Limit", motor, calibration.range_max)
|
||||||
|
|
||||||
|
self.calibration = calibration_dict
|
||||||
|
|
||||||
|
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
|
for name in self._get_motors_list(motors):
|
||||||
|
self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||||
|
|
||||||
|
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
|
for name in self._get_motors_list(motors):
|
||||||
|
self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||||
|
|
||||||
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||||
for id_ in ids_values:
|
for id_ in ids_values:
|
||||||
|
@ -166,22 +235,8 @@ class DynamixelMotorsBus(MotorsBus):
|
||||||
|
|
||||||
return half_turn_homings
|
return half_turn_homings
|
||||||
|
|
||||||
@staticmethod
|
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
|
||||||
def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]:
|
return _split_into_byte_chunks(value, length)
|
||||||
import dynamixel_sdk as dxl
|
|
||||||
|
|
||||||
if n_bytes == 1:
|
|
||||||
data = [value]
|
|
||||||
elif n_bytes == 2:
|
|
||||||
data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
|
|
||||||
elif n_bytes == 4:
|
|
||||||
data = [
|
|
||||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
|
||||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
|
||||||
dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
|
|
||||||
dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
|
|
||||||
]
|
|
||||||
return data
|
|
||||||
|
|
||||||
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
|
|
|
@ -1,3 +1,24 @@
|
||||||
|
# TODO(Steven): Consider doing the following:
|
||||||
|
# from enum import Enum
|
||||||
|
# class MyControlTableKey(Enum):
|
||||||
|
# ID = "ID"
|
||||||
|
# GOAL_SPEED = "Goal_Speed"
|
||||||
|
# ...
|
||||||
|
#
|
||||||
|
# MY_CONTROL_TABLE ={
|
||||||
|
# MyControlTableKey.ID.value: (5,1)
|
||||||
|
# MyControlTableKey.GOAL_SPEED.value: (46, 2)
|
||||||
|
# ...
|
||||||
|
# }
|
||||||
|
# This allows me do to:
|
||||||
|
# bus.write(MyControlTableKey.GOAL_SPEED, ...)
|
||||||
|
# Instead of:
|
||||||
|
# bus.write("Goal_Speed", ...)
|
||||||
|
# This is important for two reasons:
|
||||||
|
# 1. The linter will tell me if I'm trying to use an invalid key, instead of me realizing when I get the RunTimeError
|
||||||
|
# 2. We can change the value of the MyControlTableKey enums without impacting the client code
|
||||||
|
|
||||||
|
|
||||||
# {data_name: (address, size_byte)}
|
# {data_name: (address, size_byte)}
|
||||||
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table
|
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table
|
||||||
X_SERIES_CONTROL_TABLE = {
|
X_SERIES_CONTROL_TABLE = {
|
||||||
|
@ -57,13 +78,13 @@ X_SERIES_CONTROL_TABLE = {
|
||||||
|
|
||||||
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#baud-rate8
|
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#baud-rate8
|
||||||
X_SERIES_BAUDRATE_TABLE = {
|
X_SERIES_BAUDRATE_TABLE = {
|
||||||
0: 9_600,
|
9_600: 0,
|
||||||
1: 57_600,
|
57_600: 1,
|
||||||
2: 115_200,
|
115_200: 2,
|
||||||
3: 1_000_000,
|
1_000_000: 3,
|
||||||
4: 2_000_000,
|
2_000_000: 4,
|
||||||
5: 3_000_000,
|
3_000_000: 5,
|
||||||
6: 4_000_000,
|
4_000_000: 6,
|
||||||
}
|
}
|
||||||
|
|
||||||
# {data_name: size_byte}
|
# {data_name: size_byte}
|
||||||
|
|
|
@ -21,18 +21,20 @@ from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_si
|
||||||
|
|
||||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
|
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
|
||||||
from .tables import (
|
from .tables import (
|
||||||
FIRMWARE_VERSION,
|
FIRMWARE_MAJOR_VERSION,
|
||||||
|
FIRMWARE_MINOR_VERSION,
|
||||||
MODEL_BAUDRATE_TABLE,
|
MODEL_BAUDRATE_TABLE,
|
||||||
MODEL_CONTROL_TABLE,
|
MODEL_CONTROL_TABLE,
|
||||||
MODEL_ENCODING_TABLE,
|
MODEL_ENCODING_TABLE,
|
||||||
MODEL_NUMBER,
|
MODEL_NUMBER,
|
||||||
MODEL_NUMBER_TABLE,
|
MODEL_NUMBER_TABLE,
|
||||||
|
MODEL_PROTOCOL,
|
||||||
MODEL_RESOLUTION,
|
MODEL_RESOLUTION,
|
||||||
SCAN_BAUDRATES,
|
SCAN_BAUDRATES,
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_PROTOCOL_VERSION = 0
|
DEFAULT_PROTOCOL_VERSION = 0
|
||||||
BAUDRATE = 1_000_000
|
DEFAULT_BAUDRATE = 1_000_000
|
||||||
DEFAULT_TIMEOUT_MS = 1000
|
DEFAULT_TIMEOUT_MS = 1000
|
||||||
|
|
||||||
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
|
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
|
||||||
|
@ -64,6 +66,23 @@ class TorqueMode(Enum):
|
||||||
DISABLED = 0
|
DISABLED = 0
|
||||||
|
|
||||||
|
|
||||||
|
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||||
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
if length == 1:
|
||||||
|
data = [value]
|
||||||
|
elif length == 2:
|
||||||
|
data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
|
||||||
|
elif length == 4:
|
||||||
|
data = [
|
||||||
|
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||||
|
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||||
|
scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
|
||||||
|
scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
|
||||||
|
]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def patch_setPacketTimeout(self, packet_length): # noqa: N802
|
def patch_setPacketTimeout(self, packet_length): # noqa: N802
|
||||||
"""
|
"""
|
||||||
HACK: This patches the PortHandler behavior to set the correct packet timeouts.
|
HACK: This patches the PortHandler behavior to set the correct packet timeouts.
|
||||||
|
@ -84,6 +103,7 @@ class FeetechMotorsBus(MotorsBus):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
available_baudrates = deepcopy(SCAN_BAUDRATES)
|
available_baudrates = deepcopy(SCAN_BAUDRATES)
|
||||||
|
default_baudrate = DEFAULT_BAUDRATE
|
||||||
default_timeout = DEFAULT_TIMEOUT_MS
|
default_timeout = DEFAULT_TIMEOUT_MS
|
||||||
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
|
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
|
||||||
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||||
|
@ -100,9 +120,10 @@ class FeetechMotorsBus(MotorsBus):
|
||||||
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
|
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
|
||||||
):
|
):
|
||||||
super().__init__(port, motors, calibration)
|
super().__init__(port, motors, calibration)
|
||||||
|
self.protocol_version = protocol_version
|
||||||
|
self._assert_same_protocol()
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
self.protocol_version = protocol_version
|
|
||||||
self.port_handler = scs.PortHandler(self.port)
|
self.port_handler = scs.PortHandler(self.port)
|
||||||
# HACK: monkeypatch
|
# HACK: monkeypatch
|
||||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
|
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
|
||||||
|
@ -114,17 +135,132 @@ class FeetechMotorsBus(MotorsBus):
|
||||||
self._comm_success = scs.COMM_SUCCESS
|
self._comm_success = scs.COMM_SUCCESS
|
||||||
self._no_error = 0x00
|
self._no_error = 0x00
|
||||||
|
|
||||||
|
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
|
||||||
|
raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}")
|
||||||
|
|
||||||
|
def _assert_same_protocol(self) -> None:
|
||||||
|
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
|
||||||
|
raise RuntimeError("Some motors use an incompatible protocol.")
|
||||||
|
|
||||||
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
|
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
|
||||||
if instruction_name == "sync_read" and self.protocol_version == 1:
|
if instruction_name == "sync_read" and self.protocol_version == 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' instead."
|
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead."
|
||||||
|
)
|
||||||
|
if instruction_name == "broadcast_ping" and self.protocol_version == 1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _assert_same_firmware(self) -> None:
|
||||||
|
firmware_versions = self._read_firmware_version(self.ids)
|
||||||
|
if len(set(firmware_versions.values())) != 1:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Some Motors use different firmware versions. Update their firmware first using Feetech's software. "
|
||||||
|
"Visit https://www.feetechrc.com/software."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handshake(self) -> None:
|
||||||
|
self._assert_motors_exist()
|
||||||
|
self._assert_same_firmware()
|
||||||
|
|
||||||
|
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||||
|
if self.protocol_version == 0:
|
||||||
|
return self._find_single_motor_p0(motor, initial_baudrate)
|
||||||
|
else:
|
||||||
|
return self._find_single_motor_p1(motor, initial_baudrate)
|
||||||
|
|
||||||
|
def _find_single_motor_p0(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||||
|
model = self.motors[motor].model
|
||||||
|
search_baudrates = (
|
||||||
|
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
|
||||||
|
)
|
||||||
|
expected_model_nb = self.model_number_table[model]
|
||||||
|
|
||||||
|
for baudrate in search_baudrates:
|
||||||
|
self.set_baudrate(baudrate)
|
||||||
|
id_model = self.broadcast_ping()
|
||||||
|
if id_model:
|
||||||
|
found_id, found_model = next(iter(id_model.items()))
|
||||||
|
if found_model != expected_model_nb:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Found one motor on {baudrate=} with id={found_id} but it has a "
|
||||||
|
f"model number '{found_model}' different than the one expected: '{expected_model_nb}' "
|
||||||
|
f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')."
|
||||||
|
)
|
||||||
|
return baudrate, found_id
|
||||||
|
|
||||||
|
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
|
||||||
|
|
||||||
|
def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||||
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
model = self.motors[motor].model
|
||||||
|
search_baudrates = (
|
||||||
|
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
|
||||||
|
)
|
||||||
|
expected_model_nb = self.model_number_table[model]
|
||||||
|
|
||||||
|
for baudrate in search_baudrates:
|
||||||
|
self.set_baudrate(baudrate)
|
||||||
|
for id_ in range(scs.MAX_ID + 1):
|
||||||
|
found_model = self.ping(id_)
|
||||||
|
if found_model is not None and found_model != expected_model_nb:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Found one motor on {baudrate=} with id={id_} but it has a "
|
||||||
|
f"model number '{found_model}' different than the one expected: '{expected_model_nb}' "
|
||||||
|
f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')."
|
||||||
|
)
|
||||||
|
return baudrate, id_
|
||||||
|
|
||||||
|
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
|
||||||
|
|
||||||
def configure_motors(self) -> None:
|
def configure_motors(self) -> None:
|
||||||
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on the
|
for motor in self.motors:
|
||||||
# 'Return_Delay' address). We ensure this is reduced to the minimum of 2µs (value of 0).
|
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on
|
||||||
for id_ in self.ids:
|
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
|
||||||
self.write("Return_Delay_Time", id_, 0)
|
self.write("Return_Delay_Time", motor, 0)
|
||||||
|
# Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors.
|
||||||
|
# Note: this address is not in the official STS3215 Memory Table
|
||||||
|
self.write("Maximum_Acceleration", motor, 254)
|
||||||
|
self.write("Acceleration", motor, 254)
|
||||||
|
|
||||||
|
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||||
|
if self.protocol_version == 0:
|
||||||
|
offsets = self.sync_read("Homing_Offset", normalize=False)
|
||||||
|
mins = self.sync_read("Min_Position_Limit", normalize=False)
|
||||||
|
maxes = self.sync_read("Max_Position_Limit", normalize=False)
|
||||||
|
drive_modes = dict.fromkeys(self.motors, 0)
|
||||||
|
else:
|
||||||
|
offsets, mins, maxes, drive_modes = {}, {}, {}, {}
|
||||||
|
for motor in self.motors:
|
||||||
|
offsets[motor] = 0
|
||||||
|
mins[motor] = self.read("Min_Position_Limit", motor, normalize=False)
|
||||||
|
maxes[motor] = self.read("Max_Position_Limit", motor, normalize=False)
|
||||||
|
drive_modes[motor] = 0
|
||||||
|
|
||||||
|
# TODO(aliberts): add set/get_drive_mode?
|
||||||
|
|
||||||
|
calibration = {}
|
||||||
|
for name, motor in self.motors.items():
|
||||||
|
calibration[name] = MotorCalibration(
|
||||||
|
id=motor.id,
|
||||||
|
drive_mode=drive_modes[name],
|
||||||
|
homing_offset=offsets[name],
|
||||||
|
range_min=mins[name],
|
||||||
|
range_max=maxes[name],
|
||||||
|
)
|
||||||
|
|
||||||
|
return calibration
|
||||||
|
|
||||||
|
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
|
||||||
|
for motor, calibration in calibration_dict.items():
|
||||||
|
if self.protocol_version == 0:
|
||||||
|
self.write("Homing_Offset", motor, calibration.homing_offset)
|
||||||
|
self.write("Min_Position_Limit", motor, calibration.range_min)
|
||||||
|
self.write("Max_Position_Limit", motor, calibration.range_max)
|
||||||
|
|
||||||
|
self.calibration = calibration_dict
|
||||||
|
|
||||||
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
|
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
|
||||||
"""
|
"""
|
||||||
|
@ -139,15 +275,15 @@ class FeetechMotorsBus(MotorsBus):
|
||||||
|
|
||||||
return half_turn_homings
|
return half_turn_homings
|
||||||
|
|
||||||
def disable_torque(self, motors: str | list[str] | None = None) -> None:
|
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
for name in self._get_names_list(motors):
|
for name in self._get_motors_list(motors):
|
||||||
self.write("Torque_Enable", name, TorqueMode.DISABLED.value)
|
self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||||
self.write("Lock", name, 0)
|
self.write("Lock", name, 0, num_retry=num_retry)
|
||||||
|
|
||||||
def enable_torque(self, motors: str | list[str] | None = None) -> None:
|
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
for name in self._get_names_list(motors):
|
for name in self._get_motors_list(motors):
|
||||||
self.write("Torque_Enable", name, TorqueMode.ENABLED.value)
|
self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||||
self.write("Lock", name, 1)
|
self.write("Lock", name, 1, num_retry=num_retry)
|
||||||
|
|
||||||
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||||
for id_ in ids_values:
|
for id_ in ids_values:
|
||||||
|
@ -169,40 +305,10 @@ class FeetechMotorsBus(MotorsBus):
|
||||||
|
|
||||||
return ids_values
|
return ids_values
|
||||||
|
|
||||||
@staticmethod
|
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
|
||||||
def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]:
|
return _split_into_byte_chunks(value, length)
|
||||||
import scservo_sdk as scs
|
|
||||||
|
|
||||||
if n_bytes == 1:
|
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
|
||||||
data = [value]
|
|
||||||
elif n_bytes == 2:
|
|
||||||
data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
|
|
||||||
elif n_bytes == 4:
|
|
||||||
data = [
|
|
||||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
|
||||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
|
||||||
scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
|
|
||||||
scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
|
|
||||||
]
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _broadcast_ping_p1(self, known_motors_only: bool = True, num_retry: int = 0) -> dict[int, int]:
|
|
||||||
if known_motors_only:
|
|
||||||
ids = self.ids
|
|
||||||
else:
|
|
||||||
import scservo_sdk as scs
|
|
||||||
|
|
||||||
ids = range(scs.MAX_ID + 1)
|
|
||||||
|
|
||||||
ids_models = {}
|
|
||||||
for id_ in ids:
|
|
||||||
model_number = self.ping(id_, num_retry)
|
|
||||||
if model_number is not None:
|
|
||||||
ids_models[id_] = model_number
|
|
||||||
|
|
||||||
return ids_models
|
|
||||||
|
|
||||||
def _broadcast_ping_p0(self) -> tuple[dict[int, int], int]:
|
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
data_list = {}
|
data_list = {}
|
||||||
|
@ -277,9 +383,9 @@ class FeetechMotorsBus(MotorsBus):
|
||||||
rx_length = rx_length - idx
|
rx_length = rx_length - idx
|
||||||
|
|
||||||
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
||||||
if self.protocol_version == 0:
|
self._assert_protocol_is_compatible("broadcast_ping")
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
ids_status, comm = self._broadcast_ping_p0()
|
ids_status, comm = self._broadcast_ping()
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
|
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
|
||||||
|
@ -292,68 +398,37 @@ class FeetechMotorsBus(MotorsBus):
|
||||||
|
|
||||||
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
||||||
if ids_errors:
|
if ids_errors:
|
||||||
display_dict = {
|
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
|
||||||
id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()
|
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
|
||||||
}
|
|
||||||
logger.error(
|
return self._read_model_number(list(ids_status), raise_on_error)
|
||||||
f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}"
|
|
||||||
|
def _read_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]:
|
||||||
|
firmware_versions = {}
|
||||||
|
for id_ in motor_ids:
|
||||||
|
firm_ver_major, comm, error = self._read(
|
||||||
|
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
|
||||||
)
|
)
|
||||||
|
if not self._is_comm_success(comm) or self._is_error(error):
|
||||||
return self._get_model_number(list(ids_status), raise_on_error)
|
|
||||||
else:
|
|
||||||
return self._broadcast_ping_p1(num_retry=num_retry)
|
|
||||||
|
|
||||||
def _get_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
|
|
||||||
# comm, major = self._sync_read(*FIRMWARE_MAJOR_VERSION, motor_ids)
|
|
||||||
# if not self._is_comm_success(comm):
|
|
||||||
# if raise_on_error:
|
|
||||||
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
|
||||||
# return
|
|
||||||
|
|
||||||
# comm, minor = self._sync_read(*FIRMWARE_MINOR_VERSION, motor_ids)
|
|
||||||
# if not self._is_comm_success(comm):
|
|
||||||
# if raise_on_error:
|
|
||||||
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
|
||||||
# return
|
|
||||||
|
|
||||||
# return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids}
|
|
||||||
|
|
||||||
comm, firmware_versions = self._sync_read(*FIRMWARE_VERSION, motor_ids)
|
|
||||||
if not self._is_comm_success(comm):
|
|
||||||
if raise_on_error:
|
|
||||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
firm_ver_minor, comm, error = self._read(
|
||||||
|
*FIRMWARE_MINOR_VERSION, id_, raise_on_error=raise_on_error
|
||||||
|
)
|
||||||
|
if not self._is_comm_success(comm) or self._is_error(error):
|
||||||
|
return
|
||||||
|
|
||||||
|
firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}"
|
||||||
|
|
||||||
return firmware_versions
|
return firmware_versions
|
||||||
|
|
||||||
def _get_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
|
def _read_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
|
||||||
# comm, major = self._sync_read(*MODEL_MAJOR_VERSION, motor_ids)
|
|
||||||
# if not self._is_comm_success(comm):
|
|
||||||
# if raise_on_error:
|
|
||||||
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
|
||||||
# return
|
|
||||||
|
|
||||||
# comm, minor = self._sync_read(*MODEL_MINOR_VERSION, motor_ids)
|
|
||||||
# if not self._is_comm_success(comm):
|
|
||||||
# if raise_on_error:
|
|
||||||
# raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
|
||||||
# return
|
|
||||||
|
|
||||||
# return {id_: f"{major[id_]}.{minor[id_]}" for id_ in motor_ids}
|
|
||||||
if self.protocol_version == 1:
|
|
||||||
model_numbers = {}
|
model_numbers = {}
|
||||||
for id_ in motor_ids:
|
for id_ in motor_ids:
|
||||||
model_nb, comm, error = self._read(*MODEL_NUMBER, id_)
|
model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error)
|
||||||
if self._is_comm_success(comm) and not self._is_error(error):
|
if not self._is_comm_success(comm) or self._is_error(error):
|
||||||
model_numbers[id_] = model_nb
|
|
||||||
elif raise_on_error:
|
|
||||||
raise Exception # FIX
|
|
||||||
|
|
||||||
else:
|
|
||||||
comm, model_numbers = self._sync_read(*MODEL_NUMBER, motor_ids)
|
|
||||||
if not self._is_comm_success(comm):
|
|
||||||
if raise_on_error:
|
|
||||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
model_numbers[id_] = model_nb
|
||||||
|
|
||||||
return model_numbers
|
return model_numbers
|
||||||
|
|
|
@ -1,22 +1,34 @@
|
||||||
FIRMWARE_MAJOR_VERSION = (0, 1)
|
FIRMWARE_MAJOR_VERSION = (0, 1)
|
||||||
FIRMWARE_MINOR_VERSION = (1, 1)
|
FIRMWARE_MINOR_VERSION = (1, 1)
|
||||||
MODEL_MAJOR_VERSION = (3, 1)
|
|
||||||
MODEL_MINOR_VERSION = (4, 1)
|
|
||||||
|
|
||||||
FIRMWARE_VERSION = (0, 2)
|
|
||||||
MODEL_NUMBER = (3, 2)
|
MODEL_NUMBER = (3, 2)
|
||||||
|
|
||||||
# See this link for STS3215 Memory Table:
|
# TODO(Steven): Consider doing the following:
|
||||||
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
|
# from enum import Enum
|
||||||
|
# class MyControlTableKey(Enum):
|
||||||
|
# ID = "ID"
|
||||||
|
# GOAL_SPEED = "Goal_Speed"
|
||||||
|
# ...
|
||||||
|
#
|
||||||
|
# MY_CONTROL_TABLE ={
|
||||||
|
# MyControlTableKey.ID.value: (5,1)
|
||||||
|
# MyControlTableKey.GOAL_SPEED.value: (46, 2)
|
||||||
|
# ...
|
||||||
|
# }
|
||||||
|
# This allows me do to:
|
||||||
|
# bus.write(MyControlTableKey.GOAL_SPEED, ...)
|
||||||
|
# Instead of:
|
||||||
|
# bus.write("Goal_Speed", ...)
|
||||||
|
# This is important for two reasons:
|
||||||
|
# 1. The linter will tell me if I'm trying to use an invalid key, instead of me realizing when I get the RunTimeError
|
||||||
|
# 2. We can change the value of the MyControlTableKey enums without impacting the client code
|
||||||
|
|
||||||
# data_name: (address, size_byte)
|
# data_name: (address, size_byte)
|
||||||
|
# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SMS-STS-emanual-229f4476422d4059abfb1cb0
|
||||||
STS_SMS_SERIES_CONTROL_TABLE = {
|
STS_SMS_SERIES_CONTROL_TABLE = {
|
||||||
# EPROM
|
# EPROM
|
||||||
"Firmware_Version": FIRMWARE_VERSION, # read-only
|
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||||
|
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||||
"Model_Number": MODEL_NUMBER, # read-only
|
"Model_Number": MODEL_NUMBER, # read-only
|
||||||
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
|
||||||
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
|
||||||
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
|
|
||||||
# "Model_Minor_Version": MODEL_MINOR_VERSION,
|
|
||||||
"ID": (5, 1),
|
"ID": (5, 1),
|
||||||
"Baud_Rate": (6, 1),
|
"Baud_Rate": (6, 1),
|
||||||
"Return_Delay_Time": (7, 1),
|
"Return_Delay_Time": (7, 1),
|
||||||
|
@ -43,7 +55,7 @@ STS_SMS_SERIES_CONTROL_TABLE = {
|
||||||
"Protective_Torque": (34, 1),
|
"Protective_Torque": (34, 1),
|
||||||
"Protection_Time": (35, 1),
|
"Protection_Time": (35, 1),
|
||||||
"Overload_Torque": (36, 1),
|
"Overload_Torque": (36, 1),
|
||||||
"Speed_closed_loop_P_proportional_coefficient": (37, 1),
|
"Velocity_closed_loop_P_proportional_coefficient": (37, 1),
|
||||||
"Over_Current_Protection_Time": (38, 1),
|
"Over_Current_Protection_Time": (38, 1),
|
||||||
"Velocity_closed_loop_I_integral_coefficient": (39, 1),
|
"Velocity_closed_loop_I_integral_coefficient": (39, 1),
|
||||||
# SRAM
|
# SRAM
|
||||||
|
@ -51,32 +63,38 @@ STS_SMS_SERIES_CONTROL_TABLE = {
|
||||||
"Acceleration": (41, 1),
|
"Acceleration": (41, 1),
|
||||||
"Goal_Position": (42, 2),
|
"Goal_Position": (42, 2),
|
||||||
"Goal_Time": (44, 2),
|
"Goal_Time": (44, 2),
|
||||||
"Goal_Speed": (46, 2),
|
"Goal_Velocity": (46, 2),
|
||||||
"Torque_Limit": (48, 2),
|
"Torque_Limit": (48, 2),
|
||||||
"Lock": (55, 1),
|
"Lock": (55, 1),
|
||||||
"Present_Position": (56, 2), # read-only
|
"Present_Position": (56, 2), # read-only
|
||||||
"Present_Speed": (58, 2), # read-only
|
"Present_Velocity": (58, 2), # read-only
|
||||||
"Present_Load": (60, 2), # read-only
|
"Present_Load": (60, 2), # read-only
|
||||||
"Present_Voltage": (62, 1), # read-only
|
"Present_Voltage": (62, 1), # read-only
|
||||||
"Present_Temperature": (63, 1), # read-only
|
"Present_Temperature": (63, 1), # read-only
|
||||||
"Status": (65, 1), # read-only
|
"Status": (65, 1), # read-only
|
||||||
"Moving": (66, 1), # read-only
|
"Moving": (66, 1), # read-only
|
||||||
"Present_Current": (69, 2), # read-only
|
"Present_Current": (69, 2), # read-only
|
||||||
# Not in the Memory Table
|
"Goal_Position_2": (71, 2), # read-only
|
||||||
"Maximum_Acceleration": (85, 2),
|
# Factory
|
||||||
|
"Moving_Velocity": (80, 1),
|
||||||
|
"Moving_Velocity_Threshold": (80, 1),
|
||||||
|
"DTs": (81, 1), # (ms)
|
||||||
|
"Velocity_Unit_factor": (82, 1),
|
||||||
|
"Hts": (83, 1), # (ns) valid for firmware >= 2.54, other versions keep 0
|
||||||
|
"Maximum_Velocity_Limit": (84, 1),
|
||||||
|
"Maximum_Acceleration": (85, 1),
|
||||||
|
"Acceleration_Multiplier ": (86, 1), # Acceleration multiplier in effect when acceleration is 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SCSCL-emanual-cbcc8ab2e3384282a01d4bf3
|
||||||
SCS_SERIES_CONTROL_TABLE = {
|
SCS_SERIES_CONTROL_TABLE = {
|
||||||
# EPROM
|
# EPROM
|
||||||
"Firmware_Version": FIRMWARE_VERSION, # read-only
|
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||||
|
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||||
"Model_Number": MODEL_NUMBER, # read-only
|
"Model_Number": MODEL_NUMBER, # read-only
|
||||||
# "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
|
||||||
# "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
|
||||||
# "Model_Major_Version": MODEL_MAJOR_VERSION, # read-only
|
|
||||||
# "Model_Minor_Version": MODEL_MINOR_VERSION,
|
|
||||||
"ID": (5, 1),
|
"ID": (5, 1),
|
||||||
"Baud_Rate": (6, 1),
|
"Baud_Rate": (6, 1),
|
||||||
"Return_Delay": (7, 1),
|
"Return_Delay_Time": (7, 1),
|
||||||
"Response_Status_Level": (8, 1),
|
"Response_Status_Level": (8, 1),
|
||||||
"Min_Position_Limit": (9, 2),
|
"Min_Position_Limit": (9, 2),
|
||||||
"Max_Position_Limit": (11, 2),
|
"Max_Position_Limit": (11, 2),
|
||||||
|
@ -100,38 +118,45 @@ SCS_SERIES_CONTROL_TABLE = {
|
||||||
"Acceleration": (41, 1),
|
"Acceleration": (41, 1),
|
||||||
"Goal_Position": (42, 2),
|
"Goal_Position": (42, 2),
|
||||||
"Running_Time": (44, 2),
|
"Running_Time": (44, 2),
|
||||||
"Goal_Speed": (46, 2),
|
"Goal_Velocity": (46, 2),
|
||||||
"Lock": (48, 1),
|
"Lock": (48, 1),
|
||||||
"Present_Position": (56, 2), # read-only
|
"Present_Position": (56, 2), # read-only
|
||||||
"Present_Speed": (58, 2), # read-only
|
"Present_Velocity": (58, 2), # read-only
|
||||||
"Present_Load": (60, 2), # read-only
|
"Present_Load": (60, 2), # read-only
|
||||||
"Present_Voltage": (62, 1), # read-only
|
"Present_Voltage": (62, 1), # read-only
|
||||||
"Present_Temperature": (63, 1), # read-only
|
"Present_Temperature": (63, 1), # read-only
|
||||||
"Sync_Write_Flag": (64, 1), # read-only
|
"Sync_Write_Flag": (64, 1), # read-only
|
||||||
"Status": (65, 1), # read-only
|
"Status": (65, 1), # read-only
|
||||||
"Moving": (66, 1), # read-only
|
"Moving": (66, 1), # read-only
|
||||||
|
# Factory
|
||||||
|
"PWM_Maximum_Step": (78, 1),
|
||||||
|
"Moving_Velocity_Threshold*50": (79, 1),
|
||||||
|
"DTs": (80, 1), # (ms)
|
||||||
|
"Minimum_Velocity_Limit*50": (81, 1),
|
||||||
|
"Maximum_Velocity_Limit*50": (82, 1),
|
||||||
|
"Acceleration_2": (83, 1), # don't know what that is
|
||||||
}
|
}
|
||||||
|
|
||||||
STS_SMS_SERIES_BAUDRATE_TABLE = {
|
STS_SMS_SERIES_BAUDRATE_TABLE = {
|
||||||
0: 1_000_000,
|
1_000_000: 0,
|
||||||
1: 500_000,
|
500_000: 1,
|
||||||
2: 250_000,
|
250_000: 2,
|
||||||
3: 128_000,
|
128_000: 3,
|
||||||
4: 115_200,
|
115_200: 4,
|
||||||
5: 57_600,
|
57_600: 5,
|
||||||
6: 38_400,
|
38_400: 6,
|
||||||
7: 19_200,
|
19_200: 7,
|
||||||
}
|
}
|
||||||
|
|
||||||
SCS_SERIES_BAUDRATE_TABLE = {
|
SCS_SERIES_BAUDRATE_TABLE = {
|
||||||
0: 1_000_000,
|
1_000_000: 0,
|
||||||
1: 500_000,
|
500_000: 1,
|
||||||
2: 250_000,
|
250_000: 2,
|
||||||
3: 128_000,
|
128_000: 3,
|
||||||
4: 115_200,
|
115_200: 4,
|
||||||
5: 57_600,
|
57_600: 5,
|
||||||
6: 38_400,
|
38_400: 6,
|
||||||
7: 19_200,
|
19_200: 7,
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_CONTROL_TABLE = {
|
MODEL_CONTROL_TABLE = {
|
||||||
|
@ -150,7 +175,7 @@ MODEL_RESOLUTION = {
|
||||||
"scs_series": 1024,
|
"scs_series": 1024,
|
||||||
"sts3215": 4096,
|
"sts3215": 4096,
|
||||||
"sts3250": 4096,
|
"sts3250": 4096,
|
||||||
"sm8512bl": 4096,
|
"sm8512bl": 65536,
|
||||||
"scs0009": 1024,
|
"scs0009": 1024,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,7 +192,7 @@ MODEL_BAUDRATE_TABLE = {
|
||||||
# Sign-Magnitude encoding bits
|
# Sign-Magnitude encoding bits
|
||||||
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||||
"Homing_Offset": 11,
|
"Homing_Offset": 11,
|
||||||
"Goal_Speed": 15,
|
"Goal_Velocity": 15,
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_ENCODING_TABLE = {
|
MODEL_ENCODING_TABLE = {
|
||||||
|
@ -194,10 +219,19 @@ SCAN_BAUDRATES = [
|
||||||
1_000_000,
|
1_000_000,
|
||||||
]
|
]
|
||||||
|
|
||||||
# {model: model_number} TODO
|
|
||||||
MODEL_NUMBER_TABLE = {
|
MODEL_NUMBER_TABLE = {
|
||||||
"sts3215": 777,
|
"sts3215": 777,
|
||||||
"sts3250": None,
|
"sts3250": 2825,
|
||||||
"sm8512bl": None,
|
"sm8512bl": 11272,
|
||||||
"scs0009": 1284,
|
"scs0009": 1284,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MODEL_PROTOCOL = {
|
||||||
|
"sts_series": 0,
|
||||||
|
"sms_series": 0,
|
||||||
|
"scs_series": 1,
|
||||||
|
"sts3215": 0,
|
||||||
|
"sts3250": 0,
|
||||||
|
"sm8512bl": 0,
|
||||||
|
"scs0009": 1,
|
||||||
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
@ -254,6 +255,7 @@ class MotorsBus(abc.ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
available_baudrates: list[int]
|
available_baudrates: list[int]
|
||||||
|
default_baudrate: int
|
||||||
default_timeout: int
|
default_timeout: int
|
||||||
model_baudrate_table: dict[str, dict]
|
model_baudrate_table: dict[str, dict]
|
||||||
model_ctrl_table: dict[str, dict]
|
model_ctrl_table: dict[str, dict]
|
||||||
|
@ -283,6 +285,8 @@ class MotorsBus(abc.ABC):
|
||||||
self._id_to_name_dict = {m.id: name for name, m in self.motors.items()}
|
self._id_to_name_dict = {m.id: name for name, m in self.motors.items()}
|
||||||
self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()}
|
self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()}
|
||||||
|
|
||||||
|
self._validate_motors()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.motors)
|
return len(self.motors)
|
||||||
|
|
||||||
|
@ -341,7 +345,7 @@ class MotorsBus(abc.ABC):
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"'{motor}' should be int, str.")
|
raise TypeError(f"'{motor}' should be int, str.")
|
||||||
|
|
||||||
def _get_names_list(self, motors: str | list[str] | None) -> list[str]:
|
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||||
if motors is None:
|
if motors is None:
|
||||||
return self.names
|
return self.names
|
||||||
elif isinstance(motors, str):
|
elif isinstance(motors, str):
|
||||||
|
@ -375,9 +379,13 @@ class MotorsBus(abc.ABC):
|
||||||
|
|
||||||
def _assert_motors_exist(self) -> None:
|
def _assert_motors_exist(self) -> None:
|
||||||
# TODO(aliberts): collect all wrong ids/models and display them at once
|
# TODO(aliberts): collect all wrong ids/models and display them at once
|
||||||
found_models = self.broadcast_ping()
|
found_models = {}
|
||||||
|
for id_ in self.ids:
|
||||||
|
model_nb = self.ping(id_)
|
||||||
|
if model_nb is not None:
|
||||||
|
found_models[id_] = model_nb
|
||||||
expected_models = {m.id: self.model_number_table[m.model] for m in self.motors.values()}
|
expected_models = {m.id: self.model_number_table[m.model] for m in self.motors.values()}
|
||||||
if not found_models or set(found_models) != set(self.ids):
|
if set(found_models) != set(self.ids):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"{self.__class__.__name__} is supposed to have these motors: ({{id: model_nb}})"
|
f"{self.__class__.__name__} is supposed to have these motors: ({{id: model_nb}})"
|
||||||
f"\n{pformat(expected_models, indent=4, sort_dicts=False)}\n"
|
f"\n{pformat(expected_models, indent=4, sort_dicts=False)}\n"
|
||||||
|
@ -401,36 +409,36 @@ class MotorsBus(abc.ABC):
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self.port_handler.is_open
|
return self.port_handler.is_open
|
||||||
|
|
||||||
def connect(self, assert_motors_exist: bool = True) -> None:
|
def connect(self, handshake: bool = True) -> None:
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
raise DeviceAlreadyConnectedError(
|
raise DeviceAlreadyConnectedError(
|
||||||
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
|
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._connect(handshake)
|
||||||
|
self.set_timeout()
|
||||||
|
logger.debug(f"{self.__class__.__name__} connected.")
|
||||||
|
|
||||||
|
def _connect(self, handshake: bool = True) -> None:
|
||||||
try:
|
try:
|
||||||
if not self.port_handler.openPort():
|
if not self.port_handler.openPort():
|
||||||
raise OSError(f"Failed to open port '{self.port}'.")
|
raise OSError(f"Failed to open port '{self.port}'.")
|
||||||
elif assert_motors_exist:
|
elif handshake:
|
||||||
self._assert_motors_exist()
|
self._handshake()
|
||||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
|
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
|
||||||
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
|
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
self.set_timeout()
|
@abc.abstractmethod
|
||||||
logger.debug(f"{self.__class__.__name__} connected.")
|
def _handshake(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def scan_port(cls, port: str) -> dict[int, list[int]]:
|
def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]:
|
||||||
bus = cls(port, {})
|
bus = cls(port, {}, *args, **kwargs)
|
||||||
try:
|
bus._connect(handshake=False)
|
||||||
bus.port_handler.openPort()
|
|
||||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
|
||||||
raise ConnectionError(
|
|
||||||
f"Could not connect to port '{port}'. Make sure you are using the correct port."
|
|
||||||
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
|
|
||||||
) from e
|
|
||||||
baudrate_ids = {}
|
baudrate_ids = {}
|
||||||
for baudrate in tqdm(bus.available_baudrates, desc="Scanning port"):
|
for baudrate in tqdm(bus.available_baudrates, desc="Scanning port"):
|
||||||
bus.set_baudrate(baudrate)
|
bus.set_baudrate(baudrate)
|
||||||
|
@ -441,18 +449,57 @@ class MotorsBus(abc.ABC):
|
||||||
|
|
||||||
return baudrate_ids
|
return baudrate_ids
|
||||||
|
|
||||||
|
def setup_motor(
|
||||||
|
self, motor: str, initial_baudrate: int | None = None, initial_id: int | None = None
|
||||||
|
) -> None:
|
||||||
|
if not self.is_connected:
|
||||||
|
self._connect(handshake=False)
|
||||||
|
|
||||||
|
if initial_baudrate is None:
|
||||||
|
initial_baudrate, initial_id = self._find_single_motor(motor)
|
||||||
|
|
||||||
|
if initial_id is None:
|
||||||
|
_, initial_id = self._find_single_motor(motor, initial_baudrate)
|
||||||
|
|
||||||
|
model = self.motors[motor].model
|
||||||
|
target_id = self.motors[motor].id
|
||||||
|
self.set_baudrate(initial_baudrate)
|
||||||
|
|
||||||
|
# Set ID
|
||||||
|
addr, length = get_address(self.model_ctrl_table, "ID", model)
|
||||||
|
self._write(addr, length, initial_id, target_id)
|
||||||
|
|
||||||
|
# Set Baudrate
|
||||||
|
addr, length = get_address(self.model_ctrl_table, "Baud_Rate", model)
|
||||||
|
baudrate_value = self.model_baudrate_table[model][self.default_baudrate]
|
||||||
|
self._write(addr, length, target_id, baudrate_value)
|
||||||
|
|
||||||
|
self.set_baudrate(self.default_baudrate)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
|
||||||
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def configure_motors(self) -> None:
|
def configure_motors(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def disable_torque(self, motors: str | list[str] | None = None) -> None:
|
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def enable_torque(self, motors: str | list[str] | None = None) -> None:
|
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def torque_disabled(self):
|
||||||
|
self.disable_torque()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self.enable_torque()
|
||||||
|
|
||||||
def set_timeout(self, timeout_ms: int | None = None):
|
def set_timeout(self, timeout_ms: int | None = None):
|
||||||
timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout
|
timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout
|
||||||
self.port_handler.setPacketTimeoutMillis(timeout_ms)
|
self.port_handler.setPacketTimeoutMillis(timeout_ms)
|
||||||
|
@ -473,35 +520,13 @@ class MotorsBus(abc.ABC):
|
||||||
def is_calibrated(self) -> bool:
|
def is_calibrated(self) -> bool:
|
||||||
return self.calibration == self.read_calibration()
|
return self.calibration == self.read_calibration()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||||
offsets = self.sync_read("Homing_Offset", normalize=False)
|
pass
|
||||||
mins = self.sync_read("Min_Position_Limit", normalize=False)
|
|
||||||
maxes = self.sync_read("Max_Position_Limit", normalize=False)
|
|
||||||
|
|
||||||
try:
|
|
||||||
drive_modes = self.sync_read("Drive_Mode", normalize=False)
|
|
||||||
except KeyError:
|
|
||||||
drive_modes = dict.fromkeys(self.names, 0)
|
|
||||||
|
|
||||||
calibration = {}
|
|
||||||
for name, motor in self.motors.items():
|
|
||||||
calibration[name] = MotorCalibration(
|
|
||||||
id=motor.id,
|
|
||||||
drive_mode=drive_modes[name],
|
|
||||||
homing_offset=offsets[name],
|
|
||||||
range_min=mins[name],
|
|
||||||
range_max=maxes[name],
|
|
||||||
)
|
|
||||||
|
|
||||||
return calibration
|
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
|
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
|
||||||
for motor, calibration in calibration_dict.items():
|
pass
|
||||||
self.write("Homing_Offset", motor, calibration.homing_offset)
|
|
||||||
self.write("Min_Position_Limit", motor, calibration.range_min)
|
|
||||||
self.write("Max_Position_Limit", motor, calibration.range_max)
|
|
||||||
|
|
||||||
self.calibration = calibration_dict
|
|
||||||
|
|
||||||
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
|
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
|
||||||
if motors is None:
|
if motors is None:
|
||||||
|
@ -544,7 +569,7 @@ class MotorsBus(abc.ABC):
|
||||||
motors = self.names
|
motors = self.names
|
||||||
elif isinstance(motors, (str, int)):
|
elif isinstance(motors, (str, int)):
|
||||||
motors = [motors]
|
motors = [motors]
|
||||||
else:
|
elif not isinstance(motors, list):
|
||||||
raise TypeError(motors)
|
raise TypeError(motors)
|
||||||
|
|
||||||
self.reset_calibration(motors)
|
self.reset_calibration(motors)
|
||||||
|
@ -600,12 +625,15 @@ class MotorsBus(abc.ABC):
|
||||||
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
|
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
|
||||||
if not self.calibration:
|
if not self.calibration:
|
||||||
raise RuntimeError(f"{self} has no calibration registered.")
|
raise RuntimeError(f"{self} has no calibration registered.")
|
||||||
|
|
||||||
normalized_values = {}
|
normalized_values = {}
|
||||||
for id_, val in ids_values.items():
|
for id_, val in ids_values.items():
|
||||||
name = self._id_to_name(id_)
|
name = self._id_to_name(id_)
|
||||||
min_ = self.calibration[name].range_min
|
min_ = self.calibration[name].range_min
|
||||||
max_ = self.calibration[name].range_max
|
max_ = self.calibration[name].range_max
|
||||||
bounded_val = min(max_, max(min_, val))
|
bounded_val = min(max_, max(min_, val))
|
||||||
|
# TODO(Steven): normalization can go boom if max_ == min_, we should add a check probably in record_ranges_of_motions
|
||||||
|
# (which probably indicates the user forgot to move a motor, most likely a gripper-like one)
|
||||||
if self.motors[name].norm_mode is MotorNormMode.RANGE_M100_100:
|
if self.motors[name].norm_mode is MotorNormMode.RANGE_M100_100:
|
||||||
normalized_values[id_] = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
|
normalized_values[id_] = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
|
||||||
elif self.motors[name].norm_mode is MotorNormMode.RANGE_0_100:
|
elif self.motors[name].norm_mode is MotorNormMode.RANGE_0_100:
|
||||||
|
@ -617,6 +645,9 @@ class MotorsBus(abc.ABC):
|
||||||
return normalized_values
|
return normalized_values
|
||||||
|
|
||||||
def _unnormalize(self, data_name: str, ids_values: dict[int, float]) -> dict[int, int]:
|
def _unnormalize(self, data_name: str, ids_values: dict[int, float]) -> dict[int, int]:
|
||||||
|
if not self.calibration:
|
||||||
|
raise RuntimeError(f"{self} has no calibration registered.")
|
||||||
|
|
||||||
unnormalized_values = {}
|
unnormalized_values = {}
|
||||||
for id_, val in ids_values.items():
|
for id_, val in ids_values.items():
|
||||||
name = self._id_to_name(id_)
|
name = self._id_to_name(id_)
|
||||||
|
@ -642,57 +673,30 @@ class MotorsBus(abc.ABC):
|
||||||
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _serialize_data(self, value: int, n_bytes: int) -> list[int]:
|
def _serialize_data(self, value: int, length: int) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Converts an unsigned integer value into a list of byte-sized integers to be sent via a communication
|
Converts an unsigned integer value into a list of byte-sized integers to be sent via a communication
|
||||||
protocol. Depending on the protocol, split values can be in big-endian or little-endian order.
|
protocol. Depending on the protocol, split values can be in big-endian or little-endian order.
|
||||||
|
|
||||||
This function extracts the individual bytes of an integer based on the
|
Supported data length for both Feetech and Dynamixel:
|
||||||
specified number of bytes (`n_bytes`). The output is a list of integers,
|
|
||||||
each representing a byte (0-255).
|
|
||||||
|
|
||||||
**Byte order:** The function returns bytes in **little-endian format**,
|
|
||||||
meaning the least significant byte (LSB) comes first.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value (int): The unsigned integer to be converted into a byte list. Must be within
|
|
||||||
the valid range for the specified `n_bytes`.
|
|
||||||
n_bytes (int): The number of bytes to use for conversion. Supported values for both Feetech and
|
|
||||||
Dynamixel:
|
|
||||||
- 1 (for values 0 to 255)
|
- 1 (for values 0 to 255)
|
||||||
- 2 (for values 0 to 65,535)
|
- 2 (for values 0 to 65,535)
|
||||||
- 4 (for values 0 to 4,294,967,295)
|
- 4 (for values 0 to 4,294,967,295)
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `value` is negative or exceeds the maximum allowed for `n_bytes`.
|
|
||||||
NotImplementedError: If `n_bytes` is not 1, 2, or 4.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[int]: A list of integers, each representing a byte in **little-endian order**.
|
|
||||||
|
|
||||||
Examples (for a little-endian protocol):
|
|
||||||
>>> split_int_bytes(0x12, 1)
|
|
||||||
[18]
|
|
||||||
>>> split_int_bytes(0x1234, 2)
|
|
||||||
[52, 18] # 0x1234 → 0x34 0x12 (little-endian)
|
|
||||||
>>> split_int_bytes(0x12345678, 4)
|
|
||||||
[120, 86, 52, 18] # 0x12345678 → 0x78 0x56 0x34 0x12
|
|
||||||
"""
|
"""
|
||||||
if value < 0:
|
if value < 0:
|
||||||
raise ValueError(f"Negative values are not allowed: {value}")
|
raise ValueError(f"Negative values are not allowed: {value}")
|
||||||
|
|
||||||
max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(n_bytes)
|
max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(length)
|
||||||
if max_value is None:
|
if max_value is None:
|
||||||
raise NotImplementedError(f"Unsupported byte size: {n_bytes}. Expected [1, 2, 4].")
|
raise NotImplementedError(f"Unsupported byte size: {length}. Expected [1, 2, 4].")
|
||||||
|
|
||||||
if value > max_value:
|
if value > max_value:
|
||||||
raise ValueError(f"Value {value} exceeds the maximum for {n_bytes} bytes ({max_value}).")
|
raise ValueError(f"Value {value} exceeds the maximum for {length} bytes ({max_value}).")
|
||||||
|
|
||||||
return self._split_into_byte_chunks(value, n_bytes)
|
return self._split_into_byte_chunks(value, length)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _split_into_byte_chunks(value: int, n_bytes: int) -> list[int]:
|
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
|
||||||
"""Convert an integer into a list of byte-sized integers."""
|
"""Convert an integer into a list of byte-sized integers."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -711,7 +715,7 @@ class MotorsBus(abc.ABC):
|
||||||
return
|
return
|
||||||
if self._is_error(error):
|
if self._is_error(error):
|
||||||
if raise_on_error:
|
if raise_on_error:
|
||||||
raise RuntimeError(self.packet_handler.getTxRxResult(comm))
|
raise RuntimeError(self.packet_handler.getRxPacketError(error))
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -736,19 +740,10 @@ class MotorsBus(abc.ABC):
|
||||||
|
|
||||||
id_ = self.motors[motor].id
|
id_ = self.motors[motor].id
|
||||||
model = self.motors[motor].model
|
model = self.motors[motor].model
|
||||||
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||||
|
|
||||||
value, comm, error = self._read(addr, n_bytes, id_, num_retry=num_retry)
|
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
||||||
if not self._is_comm_success(comm):
|
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||||
raise ConnectionError(
|
|
||||||
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
|
||||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
|
||||||
)
|
|
||||||
elif self._is_error(error):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
|
||||||
f"\n{self.packet_handler.getRxPacketError(error)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
id_value = self._decode_sign(data_name, {id_: value})
|
id_value = self._decode_sign(data_name, {id_: value})
|
||||||
|
|
||||||
|
@ -757,25 +752,39 @@ class MotorsBus(abc.ABC):
|
||||||
|
|
||||||
return id_value[id_]
|
return id_value[id_]
|
||||||
|
|
||||||
def _read(self, addr: int, n_bytes: int, motor_id: int, num_retry: int = 0) -> tuple[int, int]:
|
def _read(
|
||||||
if n_bytes == 1:
|
self,
|
||||||
|
address: int,
|
||||||
|
length: int,
|
||||||
|
motor_id: int,
|
||||||
|
*,
|
||||||
|
num_retry: int = 0,
|
||||||
|
raise_on_error: bool = True,
|
||||||
|
err_msg: str = "",
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
if length == 1:
|
||||||
read_fn = self.packet_handler.read1ByteTxRx
|
read_fn = self.packet_handler.read1ByteTxRx
|
||||||
elif n_bytes == 2:
|
elif length == 2:
|
||||||
read_fn = self.packet_handler.read2ByteTxRx
|
read_fn = self.packet_handler.read2ByteTxRx
|
||||||
elif n_bytes == 4:
|
elif length == 4:
|
||||||
read_fn = self.packet_handler.read4ByteTxRx
|
read_fn = self.packet_handler.read4ByteTxRx
|
||||||
else:
|
else:
|
||||||
raise ValueError(n_bytes)
|
raise ValueError(length)
|
||||||
|
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
value, comm, error = read_fn(self.port_handler, motor_id, addr)
|
value, comm, error = read_fn(self.port_handler, motor_id, address)
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Failed to read @{addr=} ({n_bytes=}) on {motor_id=} ({n_try=}): "
|
f"Failed to read @{address=} ({length=}) on {motor_id=} ({n_try=}): "
|
||||||
+ self.packet_handler.getTxRxResult(comm)
|
+ self.packet_handler.getTxRxResult(comm)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self._is_comm_success(comm) and raise_on_error:
|
||||||
|
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||||
|
elif self._is_error(error) and raise_on_error:
|
||||||
|
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
|
||||||
|
|
||||||
return value, comm, error
|
return value, comm, error
|
||||||
|
|
||||||
def write(
|
def write(
|
||||||
|
@ -788,38 +797,42 @@ class MotorsBus(abc.ABC):
|
||||||
|
|
||||||
id_ = self.motors[motor].id
|
id_ = self.motors[motor].id
|
||||||
model = self.motors[motor].model
|
model = self.motors[motor].model
|
||||||
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||||
|
|
||||||
if normalize and data_name in self.normalized_data:
|
if normalize and data_name in self.normalized_data:
|
||||||
value = self._unnormalize(data_name, {id_: value})[id_]
|
value = self._unnormalize(data_name, {id_: value})[id_]
|
||||||
|
|
||||||
value = self._encode_sign(data_name, {id_: value})[id_]
|
value = self._encode_sign(data_name, {id_: value})[id_]
|
||||||
|
|
||||||
comm, error = self._write(addr, n_bytes, id_, value, num_retry=num_retry)
|
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||||
if not self._is_comm_success(comm):
|
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||||
raise ConnectionError(
|
|
||||||
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
|
||||||
f"\n{self.packet_handler.getTxRxResult(comm)}"
|
|
||||||
)
|
|
||||||
elif self._is_error(error):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
|
||||||
f"\n{self.packet_handler.getRxPacketError(error)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _write(
|
def _write(
|
||||||
self, addr: int, n_bytes: int, motor_id: int, value: int, num_retry: int = 0
|
self,
|
||||||
|
addr: int,
|
||||||
|
length: int,
|
||||||
|
motor_id: int,
|
||||||
|
value: int,
|
||||||
|
*,
|
||||||
|
num_retry: int = 0,
|
||||||
|
raise_on_error: bool = True,
|
||||||
|
err_msg: str = "",
|
||||||
) -> tuple[int, int]:
|
) -> tuple[int, int]:
|
||||||
data = self._serialize_data(value, n_bytes)
|
data = self._serialize_data(value, length)
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, n_bytes, data)
|
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, length, data)
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Failed to sync write @{addr=} ({n_bytes=}) on id={motor_id} with {value=} ({n_try=}): "
|
f"Failed to sync write @{addr=} ({length=}) on id={motor_id} with {value=} ({n_try=}): "
|
||||||
+ self.packet_handler.getTxRxResult(comm)
|
+ self.packet_handler.getTxRxResult(comm)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self._is_comm_success(comm) and raise_on_error:
|
||||||
|
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||||
|
elif self._is_error(error) and raise_on_error:
|
||||||
|
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
|
||||||
|
|
||||||
return comm, error
|
return comm, error
|
||||||
|
|
||||||
def sync_read(
|
def sync_read(
|
||||||
|
@ -837,7 +850,7 @@ class MotorsBus(abc.ABC):
|
||||||
|
|
||||||
self._assert_protocol_is_compatible("sync_read")
|
self._assert_protocol_is_compatible("sync_read")
|
||||||
|
|
||||||
names = self._get_names_list(motors)
|
names = self._get_motors_list(motors)
|
||||||
ids = [self.motors[name].id for name in names]
|
ids = [self.motors[name].id for name in names]
|
||||||
models = [self.motors[name].model for name in names]
|
models = [self.motors[name].model for name in names]
|
||||||
|
|
||||||
|
@ -845,13 +858,11 @@ class MotorsBus(abc.ABC):
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
|
|
||||||
model = next(iter(models))
|
model = next(iter(models))
|
||||||
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||||
|
|
||||||
comm, ids_values = self._sync_read(addr, n_bytes, ids, num_retry=num_retry)
|
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
||||||
if not self._is_comm_success(comm):
|
ids_values, _ = self._sync_read(
|
||||||
raise ConnectionError(
|
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||||
f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
|
||||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ids_values = self._decode_sign(data_name, ids_values)
|
ids_values = self._decode_sign(data_name, ids_values)
|
||||||
|
@ -862,25 +873,35 @@ class MotorsBus(abc.ABC):
|
||||||
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
|
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
|
||||||
|
|
||||||
def _sync_read(
|
def _sync_read(
|
||||||
self, addr: int, n_bytes: int, motor_ids: list[int], num_retry: int = 0
|
self,
|
||||||
) -> tuple[int, dict[int, int]]:
|
addr: int,
|
||||||
self._setup_sync_reader(motor_ids, addr, n_bytes)
|
length: int,
|
||||||
|
motor_ids: list[int],
|
||||||
|
*,
|
||||||
|
num_retry: int = 0,
|
||||||
|
raise_on_error: bool = True,
|
||||||
|
err_msg: str = "",
|
||||||
|
) -> tuple[dict[int, int], int]:
|
||||||
|
self._setup_sync_reader(motor_ids, addr, length)
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
comm = self.sync_reader.txRxPacket()
|
comm = self.sync_reader.txRxPacket()
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Failed to sync read @{addr=} ({n_bytes=}) on {motor_ids=} ({n_try=}): "
|
f"Failed to sync read @{addr=} ({length=}) on {motor_ids=} ({n_try=}): "
|
||||||
+ self.packet_handler.getTxRxResult(comm)
|
+ self.packet_handler.getTxRxResult(comm)
|
||||||
)
|
)
|
||||||
|
|
||||||
values = {id_: self.sync_reader.getData(id_, addr, n_bytes) for id_ in motor_ids}
|
if not self._is_comm_success(comm) and raise_on_error:
|
||||||
return comm, values
|
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||||
|
|
||||||
def _setup_sync_reader(self, motor_ids: list[int], addr: int, n_bytes: int) -> None:
|
values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids}
|
||||||
|
return values, comm
|
||||||
|
|
||||||
|
def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None:
|
||||||
self.sync_reader.clearParam()
|
self.sync_reader.clearParam()
|
||||||
self.sync_reader.start_address = addr
|
self.sync_reader.start_address = addr
|
||||||
self.sync_reader.data_length = n_bytes
|
self.sync_reader.data_length = length
|
||||||
for id_ in motor_ids:
|
for id_ in motor_ids:
|
||||||
self.sync_reader.addParam(id_)
|
self.sync_reader.addParam(id_)
|
||||||
|
|
||||||
|
@ -888,15 +909,15 @@ class MotorsBus(abc.ABC):
|
||||||
# Would have to handle the logic of checking if a packet has been sent previously though but doable.
|
# Would have to handle the logic of checking if a packet has been sent previously though but doable.
|
||||||
# This could be at the cost of increase latency between the moment the data is produced by the motors and
|
# This could be at the cost of increase latency between the moment the data is produced by the motors and
|
||||||
# the moment it is used by a policy.
|
# the moment it is used by a policy.
|
||||||
# def _async_read(self, motor_ids: list[int], address: int, n_bytes: int):
|
# def _async_read(self, motor_ids: list[int], address: int, length: int):
|
||||||
# if self.sync_reader.start_address != address or self.sync_reader.data_length != n_bytes or ...:
|
# if self.sync_reader.start_address != address or self.sync_reader.data_length != length or ...:
|
||||||
# self._setup_sync_reader(motor_ids, address, n_bytes)
|
# self._setup_sync_reader(motor_ids, address, length)
|
||||||
# else:
|
# else:
|
||||||
# self.sync_reader.rxPacket()
|
# self.sync_reader.rxPacket()
|
||||||
# self.sync_reader.txPacket()
|
# self.sync_reader.txPacket()
|
||||||
|
|
||||||
# for id_ in motor_ids:
|
# for id_ in motor_ids:
|
||||||
# value = self.sync_reader.getData(id_, address, n_bytes)
|
# value = self.sync_reader.getData(id_, address, length)
|
||||||
|
|
||||||
def sync_write(
|
def sync_write(
|
||||||
self,
|
self,
|
||||||
|
@ -917,39 +938,46 @@ class MotorsBus(abc.ABC):
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
|
|
||||||
model = next(iter(models))
|
model = next(iter(models))
|
||||||
addr, n_bytes = get_address(self.model_ctrl_table, model, data_name)
|
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||||
|
|
||||||
if normalize and data_name in self.normalized_data:
|
if normalize and data_name in self.normalized_data:
|
||||||
ids_values = self._unnormalize(data_name, ids_values)
|
ids_values = self._unnormalize(data_name, ids_values)
|
||||||
|
|
||||||
ids_values = self._encode_sign(data_name, ids_values)
|
ids_values = self._encode_sign(data_name, ids_values)
|
||||||
|
|
||||||
comm = self._sync_write(addr, n_bytes, ids_values, num_retry=num_retry)
|
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||||
if not self._is_comm_success(comm):
|
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||||
raise ConnectionError(
|
|
||||||
f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
|
||||||
f"\n{self.packet_handler.getTxRxResult(comm)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _sync_write(self, addr: int, n_bytes: int, ids_values: dict[int, int], num_retry: int = 0) -> int:
|
def _sync_write(
|
||||||
self._setup_sync_writer(ids_values, addr, n_bytes)
|
self,
|
||||||
|
addr: int,
|
||||||
|
length: int,
|
||||||
|
ids_values: dict[int, int],
|
||||||
|
num_retry: int = 0,
|
||||||
|
raise_on_error: bool = True,
|
||||||
|
err_msg: str = "",
|
||||||
|
) -> int:
|
||||||
|
self._setup_sync_writer(ids_values, addr, length)
|
||||||
for n_try in range(1 + num_retry):
|
for n_try in range(1 + num_retry):
|
||||||
comm = self.sync_writer.txPacket()
|
comm = self.sync_writer.txPacket()
|
||||||
if self._is_comm_success(comm):
|
if self._is_comm_success(comm):
|
||||||
break
|
break
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Failed to sync write @{addr=} ({n_bytes=}) with {ids_values=} ({n_try=}): "
|
f"Failed to sync write @{addr=} ({length=}) with {ids_values=} ({n_try=}): "
|
||||||
+ self.packet_handler.getTxRxResult(comm)
|
+ self.packet_handler.getTxRxResult(comm)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self._is_comm_success(comm) and raise_on_error:
|
||||||
|
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||||
|
|
||||||
return comm
|
return comm
|
||||||
|
|
||||||
def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, n_bytes: int) -> None:
|
def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None:
|
||||||
self.sync_writer.clearParam()
|
self.sync_writer.clearParam()
|
||||||
self.sync_writer.start_address = addr
|
self.sync_writer.start_address = addr
|
||||||
self.sync_writer.data_length = n_bytes
|
self.sync_writer.data_length = length
|
||||||
for id_, value in ids_values.items():
|
for id_, value in ids_values.items():
|
||||||
data = self._serialize_data(value, n_bytes)
|
data = self._serialize_data(value, length)
|
||||||
self.sync_writer.addParam(id_, data)
|
self.sync_writer.addParam(id_, data)
|
||||||
|
|
||||||
def disconnect(self, disable_torque: bool = True) -> None:
|
def disconnect(self, disable_torque: bool = True) -> None:
|
||||||
|
@ -961,7 +989,7 @@ class MotorsBus(abc.ABC):
|
||||||
if disable_torque:
|
if disable_torque:
|
||||||
self.port_handler.clearPort()
|
self.port_handler.clearPort()
|
||||||
self.port_handler.is_using = False
|
self.port_handler.is_using = False
|
||||||
self.disable_torque()
|
self.disable_torque(num_retry=5)
|
||||||
|
|
||||||
self.port_handler.closePort()
|
self.port_handler.closePort()
|
||||||
logger.debug(f"{self.__class__.__name__} disconnected.")
|
logger.debug(f"{self.__class__.__name__} disconnected.")
|
||||||
|
|
|
@ -24,7 +24,7 @@ Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||||
|
|
||||||
Install pi0 extra dependencies:
|
Install pi0 extra dependencies:
|
||||||
```bash
|
```bash
|
||||||
pip install --no-binary=av -e ".[pi0]"
|
pip install -e ".[pi0]"
|
||||||
```
|
```
|
||||||
|
|
||||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||||
|
|
|
@ -122,7 +122,7 @@ class KochFollower(Robot):
|
||||||
|
|
||||||
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
||||||
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
||||||
logger.info(
|
print(
|
||||||
f"Move all joints except {full_turn_motors} sequentially through their entire "
|
f"Move all joints except {full_turn_motors} sequentially through their entire "
|
||||||
"ranges of motion.\nRecording positions. Press ENTER to stop..."
|
"ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||||
)
|
)
|
||||||
|
@ -146,21 +146,21 @@ class KochFollower(Robot):
|
||||||
logger.info(f"Calibration saved to {self.calibration_fpath}")
|
logger.info(f"Calibration saved to {self.calibration_fpath}")
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
self.arm.disable_torque()
|
with self.arm.torque_disabled():
|
||||||
self.arm.configure_motors()
|
self.arm.configure_motors()
|
||||||
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
|
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
|
||||||
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while
|
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
|
||||||
# assembling the arm, you could end up with a servo with a position 0 or 4095 at a crucial
|
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point
|
||||||
# point
|
|
||||||
for name in self.arm.names:
|
for name in self.arm.names:
|
||||||
if name != "gripper":
|
if name != "gripper":
|
||||||
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
|
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
|
||||||
|
|
||||||
# Use 'position control current based' for gripper to be limited by the limit of the current.
|
# Use 'position control current based' for gripper to be limited by the limit of the current. For
|
||||||
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
|
# the follower gripper, it means it can grasp an object without forcing too much even tho, its
|
||||||
# its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
# goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||||
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
|
# For the leader gripper, it means we can use it as a physical trigger, since we can force with
|
||||||
# to make it move, and it will move back to its original target position when we release the force.
|
# our finger to make it move, and it will move back to its original target position when we
|
||||||
|
# release the force.
|
||||||
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||||
|
|
||||||
# Set better PID values to close the gap between recorded states and actions
|
# Set better PID values to close the gap between recorded states and actions
|
||||||
|
@ -168,7 +168,6 @@ class KochFollower(Robot):
|
||||||
self.arm.write("Position_P_Gain", "elbow_flex", 1500)
|
self.arm.write("Position_P_Gain", "elbow_flex", 1500)
|
||||||
self.arm.write("Position_I_Gain", "elbow_flex", 0)
|
self.arm.write("Position_I_Gain", "elbow_flex", 0)
|
||||||
self.arm.write("Position_D_Gain", "elbow_flex", 600)
|
self.arm.write("Position_D_Gain", "elbow_flex", 600)
|
||||||
self.arm.enable_torque()
|
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, Any]:
|
def get_observation(self) -> dict[str, Any]:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# TODO(Steven): Update README
|
||||||
|
|
||||||
# Using the [LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi) Robot with LeRobot
|
# Using the [LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi) Robot with LeRobot
|
||||||
|
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
@ -67,9 +69,15 @@ conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
#### 5. Install ffmpeg in your environment:
|
||||||
|
When using `miniconda`, install `ffmpeg` in your environment:
|
||||||
```bash
|
```bash
|
||||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
conda install ffmpeg -c conda-forge
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 6. Install LeRobot with dependencies for the feetech motors:
|
||||||
|
```bash
|
||||||
|
cd ~/lerobot && pip install -e ".[feetech]"
|
||||||
```
|
```
|
||||||
|
|
||||||
## C. Install LeRobot on laptop
|
## C. Install LeRobot on laptop
|
||||||
|
@ -108,9 +116,15 @@ conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
#### 5. Install ffmpeg in your environment:
|
||||||
|
When using `miniconda`, install `ffmpeg` in your environment:
|
||||||
```bash
|
```bash
|
||||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
conda install ffmpeg -c conda-forge
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 6. Install LeRobot with dependencies for the feetech motors:
|
||||||
|
```bash
|
||||||
|
cd ~/lerobot && pip install -e ".[feetech]"
|
||||||
```
|
```
|
||||||
|
|
||||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
|
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
|
||||||
|
@ -182,11 +196,11 @@ sudo chmod 666 /dev/ttyACM1
|
||||||
|
|
||||||
#### d. Update config file
|
#### d. Update config file
|
||||||
|
|
||||||
IMPORTANTLY: Now that you have your ports of leader and follower arm and ip address of the mobile-so100, update the **ip** in Network configuration, **port** in leader_arms and **port** in lekiwi. In the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file. Where you will find something like:
|
IMPORTANTLY: Now that you have your ports of leader and follower arm and ip address of the mobile-so100, update the **ip** in Network configuration, **port** in leader_arms and **port** in lekiwi. In the [`LeKiwiConfig`](../lerobot/common/robot_devices/robots/configs.py) file. Where you will find something like:
|
||||||
```python
|
```python
|
||||||
@RobotConfig.register_subclass("lekiwi")
|
@RobotConfig.register_subclass("lekiwi")
|
||||||
@dataclass
|
@dataclass
|
||||||
class LeKiwiRobotConfig(RobotConfig):
|
class LeKiwiConfig(RobotConfig):
|
||||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||||
# the number of motors in your follower arms.
|
# the number of motors in your follower arms.
|
||||||
|
@ -269,7 +283,7 @@ For the wired LeKiwi version your configured IP address should refer to your own
|
||||||
```python
|
```python
|
||||||
@RobotConfig.register_subclass("lekiwi")
|
@RobotConfig.register_subclass("lekiwi")
|
||||||
@dataclass
|
@dataclass
|
||||||
class LeKiwiRobotConfig(RobotConfig):
|
class LeKiwiConfig(RobotConfig):
|
||||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||||
# the number of motors in your follower arms.
|
# the number of motors in your follower arms.
|
||||||
|
@ -412,6 +426,8 @@ python lerobot/scripts/control_robot.py \
|
||||||
--control.fps=30
|
--control.fps=30
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. For the `--control.type=remote_robot` you will also need to set `--control.viewer_ip` and `--control.viewer_port`
|
||||||
|
|
||||||
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
||||||
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
||||||
| ---------- | ------------------ | ---------------------- |
|
| ---------- | ------------------ | ---------------------- |
|
||||||
|
@ -432,7 +448,7 @@ You should see on your laptop something like this: ```[INFO] Connected to remote
|
||||||
| F | Decrease speed |
|
| F | Decrease speed |
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
|
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiConfig`](../lerobot/common/robot_devices/robots/configs.py).
|
||||||
|
|
||||||
### Wired version
|
### Wired version
|
||||||
If you have the **wired** LeKiwi version please run all commands including both these teleoperation commands on your laptop.
|
If you have the **wired** LeKiwi version please run all commands including both these teleoperation commands on your laptop.
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .config_lekiwi import LeKiwiClientConfig, LeKiwiConfig
|
||||||
|
from .lekiwi import LeKiwi
|
||||||
|
from .lekiwi_client import LeKiwiClient
|
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import enum
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from lerobot.common.cameras.configs import CameraConfig
|
||||||
|
from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
|
|
||||||
|
from ..config import RobotConfig
|
||||||
|
|
||||||
|
|
||||||
|
class RobotMode(enum.Enum):
|
||||||
|
TELEOP = 0
|
||||||
|
AUTO = 1
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Consider sending config at initial step over a socket
|
||||||
|
# However, this isn't practical because anyways we have to configure the
|
||||||
|
# socket ports to begin with
|
||||||
|
@RobotConfig.register_subclass("lekiwi")
|
||||||
|
@dataclass
|
||||||
|
class LeKiwiConfig(RobotConfig):
|
||||||
|
port = "/dev/ttyACM0" # port to connect to the bus
|
||||||
|
|
||||||
|
disable_torque_on_disconnect: bool = True
|
||||||
|
|
||||||
|
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||||
|
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||||
|
# the number of motors in your follower arms.
|
||||||
|
max_relative_target: int | None = None
|
||||||
|
|
||||||
|
cameras: dict[str, CameraConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"front": OpenCVCameraConfig(
|
||||||
|
camera_index="/dev/video1", fps=30, width=640, height=480, rotation=90
|
||||||
|
),
|
||||||
|
"wrist": OpenCVCameraConfig(
|
||||||
|
camera_index="/dev/video4", fps=30, width=640, height=480, rotation=180
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Network Configuration
|
||||||
|
port_zmq_cmd: int = 5555
|
||||||
|
port_zmq_observations: int = 5556
|
||||||
|
|
||||||
|
|
||||||
|
@RobotConfig.register_subclass("lekiwi_client")
|
||||||
|
@dataclass
|
||||||
|
class LeKiwiClientConfig(RobotConfig):
|
||||||
|
# Network Configuration
|
||||||
|
remote_ip: str = "172.18.129.208"
|
||||||
|
port_zmq_cmd: int = 5555
|
||||||
|
port_zmq_observations: int = 5556
|
||||||
|
|
||||||
|
teleop_keys: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
# Movement
|
||||||
|
"forward": "w",
|
||||||
|
"backward": "s",
|
||||||
|
"left": "a",
|
||||||
|
"right": "d",
|
||||||
|
"rotate_left": "z",
|
||||||
|
"rotate_right": "x",
|
||||||
|
# Speed control
|
||||||
|
"speed_up": "r",
|
||||||
|
"speed_down": "f",
|
||||||
|
# quit teleop
|
||||||
|
"quit": "q",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
robot_mode: RobotMode | None = None
|
|
@ -1,89 +0,0 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
from lerobot.common.cameras.configs import CameraConfig
|
|
||||||
from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
|
||||||
from lerobot.common.motors.configs import FeetechMotorsBusConfig, MotorsBusConfig
|
|
||||||
from lerobot.common.robots.config import RobotConfig
|
|
||||||
|
|
||||||
|
|
||||||
@RobotConfig.register_subclass("lekiwi")
|
|
||||||
@dataclass
|
|
||||||
class LeKiwiRobotConfig(RobotConfig):
|
|
||||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
|
||||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
|
||||||
# the number of motors in your follower arms.
|
|
||||||
max_relative_target: int | None = None
|
|
||||||
|
|
||||||
# Network Configuration
|
|
||||||
ip: str = "192.168.0.193"
|
|
||||||
port: int = 5555
|
|
||||||
video_port: int = 5556
|
|
||||||
|
|
||||||
cameras: dict[str, CameraConfig] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"front": OpenCVCameraConfig(
|
|
||||||
camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
|
|
||||||
),
|
|
||||||
"wrist": OpenCVCameraConfig(
|
|
||||||
camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
|
||||||
|
|
||||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"main": FeetechMotorsBusConfig(
|
|
||||||
port="/dev/tty.usbmodem585A0077581",
|
|
||||||
motors={
|
|
||||||
# name: (index, model)
|
|
||||||
"shoulder_pan": [1, "sts3215"],
|
|
||||||
"shoulder_lift": [2, "sts3215"],
|
|
||||||
"elbow_flex": [3, "sts3215"],
|
|
||||||
"wrist_flex": [4, "sts3215"],
|
|
||||||
"wrist_roll": [5, "sts3215"],
|
|
||||||
"gripper": [6, "sts3215"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"main": FeetechMotorsBusConfig(
|
|
||||||
port="/dev/ttyACM0",
|
|
||||||
motors={
|
|
||||||
# name: (index, model)
|
|
||||||
"shoulder_pan": [1, "sts3215"],
|
|
||||||
"shoulder_lift": [2, "sts3215"],
|
|
||||||
"elbow_flex": [3, "sts3215"],
|
|
||||||
"wrist_flex": [4, "sts3215"],
|
|
||||||
"wrist_roll": [5, "sts3215"],
|
|
||||||
"gripper": [6, "sts3215"],
|
|
||||||
"left_wheel": (7, "sts3215"),
|
|
||||||
"back_wheel": (8, "sts3215"),
|
|
||||||
"right_wheel": (9, "sts3215"),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
teleop_keys: dict[str, str] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
# Movement
|
|
||||||
"forward": "w",
|
|
||||||
"backward": "s",
|
|
||||||
"left": "a",
|
|
||||||
"right": "d",
|
|
||||||
"rotate_left": "z",
|
|
||||||
"rotate_right": "x",
|
|
||||||
# Speed control
|
|
||||||
"speed_up": "r",
|
|
||||||
"speed_down": "f",
|
|
||||||
# quit teleop
|
|
||||||
"quit": "q",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
mock: bool = False
|
|
|
@ -0,0 +1,261 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||||
|
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
|
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
|
from lerobot.common.motors.feetech import (
|
||||||
|
FeetechMotorsBus,
|
||||||
|
OperatingMode,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..robot import Robot
|
||||||
|
from ..utils import ensure_safe_goal_position
|
||||||
|
from .config_lekiwi import LeKiwiConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LeKiwi(Robot):
|
||||||
|
"""
|
||||||
|
The robot includes a three omniwheel mobile base and a remote follower arm.
|
||||||
|
The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
|
||||||
|
forwarded to the remote follower arm (after applying a safety clamp).
|
||||||
|
In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = LeKiwiConfig
|
||||||
|
name = "lekiwi"
|
||||||
|
|
||||||
|
def __init__(self, config: LeKiwiConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.id = config.id
|
||||||
|
self.bus = FeetechMotorsBus(
|
||||||
|
port=self.config.port,
|
||||||
|
motors={
|
||||||
|
# arm
|
||||||
|
"arm_shoulder_pan": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
|
"arm_shoulder_lift": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
|
"arm_elbow_flex": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
|
"arm_wrist_flex": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
|
"arm_wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
|
"arm_gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||||
|
# base
|
||||||
|
"base_left_wheel": Motor(7, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
|
"base_right_wheel": Motor(8, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
|
"base_back_wheel": Motor(9, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
|
},
|
||||||
|
calibration=self.calibration,
|
||||||
|
)
|
||||||
|
self.arm_motors = [m for m in self.bus.names if m.startswith("arm")]
|
||||||
|
self.base_motors = [m for m in self.bus.names if m.startswith("base")]
|
||||||
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_feature(self) -> dict:
|
||||||
|
return {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(self.bus),),
|
||||||
|
"names": {"motors": list(self.bus.motors)},
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_feature(self) -> dict:
|
||||||
|
return self.state_feature
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_features(self) -> dict[str, dict]:
|
||||||
|
cam_ft = {}
|
||||||
|
for cam_key, cam in self.cameras.items():
|
||||||
|
cam_ft[cam_key] = {
|
||||||
|
"shape": (cam.height, cam.width, cam.channels),
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": None,
|
||||||
|
}
|
||||||
|
return cam_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
# TODO(aliberts): add cam.is_connected for cam in self.cameras
|
||||||
|
return self.bus.is_connected
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
if self.is_connected:
|
||||||
|
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||||
|
|
||||||
|
self.bus.connect()
|
||||||
|
if not self.is_calibrated:
|
||||||
|
self.calibrate()
|
||||||
|
|
||||||
|
for cam in self.cameras.values():
|
||||||
|
cam.connect()
|
||||||
|
|
||||||
|
self.configure()
|
||||||
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.bus.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
logger.info(f"\nRunning calibration of {self}")
|
||||||
|
|
||||||
|
motors = self.arm_motors + self.base_motors
|
||||||
|
|
||||||
|
self.bus.disable_torque(self.arm_motors)
|
||||||
|
for name in self.arm_motors:
|
||||||
|
self.bus.write("Operating_Mode", name, OperatingMode.POSITION.value)
|
||||||
|
|
||||||
|
input("Move robot to the middle of its range of motion and press ENTER....")
|
||||||
|
homing_offsets = self.bus.set_half_turn_homings(motors)
|
||||||
|
|
||||||
|
# TODO(Steven): Previously homig_offsets was called only on self.arm_motors
|
||||||
|
# After a discussion, we said it was better to keep it like this and then
|
||||||
|
# just populate with the rest of motors. However, I don't know which value
|
||||||
|
# should we use for this
|
||||||
|
# homing_offsets.update({k,None???} for k in self.base_motors)
|
||||||
|
|
||||||
|
# TODO(Steven): Might be worth to do this also in other robots but it should be added in the docs
|
||||||
|
full_turn_motor = [
|
||||||
|
motor for motor in motors if any(keyword in motor for keyword in ["wheel", "wrist"])
|
||||||
|
]
|
||||||
|
unknown_range_motors = [motor for motor in motors if motor not in full_turn_motor]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Move all arm joints except '{full_turn_motor}' sequentially through their "
|
||||||
|
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||||
|
)
|
||||||
|
range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors)
|
||||||
|
for name in full_turn_motor:
|
||||||
|
range_mins[name] = 0
|
||||||
|
range_maxes[name] = 4095
|
||||||
|
|
||||||
|
self.calibration = {}
|
||||||
|
for name, motor in self.bus.motors.items():
|
||||||
|
self.calibration[name] = MotorCalibration(
|
||||||
|
id=motor.id,
|
||||||
|
drive_mode=0,
|
||||||
|
homing_offset=homing_offsets[name],
|
||||||
|
range_min=range_mins[name],
|
||||||
|
range_max=range_maxes[name],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bus.write_calibration(self.calibration)
|
||||||
|
self._save_calibration()
|
||||||
|
print("Calibration saved to", self.calibration_fpath)
|
||||||
|
|
||||||
|
def configure(self):
|
||||||
|
# Set-up arm actuators (position mode)
|
||||||
|
# We assume that at connection time, arm is in a rest position,
|
||||||
|
# and torque can be safely disabled to run calibration.
|
||||||
|
self.bus.disable_torque(self.arm_motors)
|
||||||
|
for name in self.arm_motors:
|
||||||
|
self.bus.write("Operating_Mode", name, OperatingMode.POSITION.value)
|
||||||
|
# Set P_Coefficient to lower value to avoid shakiness (Default is 32)
|
||||||
|
self.bus.write("P_Coefficient", name, 16)
|
||||||
|
# Set I_Coefficient and D_Coefficient to default value 0 and 32
|
||||||
|
self.bus.write("I_Coefficient", name, 0)
|
||||||
|
self.bus.write("D_Coefficient", name, 32)
|
||||||
|
# Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
|
||||||
|
# the motors. Note: this configuration is not in the official STS3215 Memory Table
|
||||||
|
self.bus.write("Maximum_Acceleration", name, 254)
|
||||||
|
self.bus.write("Acceleration", name, 254)
|
||||||
|
|
||||||
|
for name in self.base_motors:
|
||||||
|
self.bus.write("Operating_Mode", name, OperatingMode.VELOCITY.value)
|
||||||
|
|
||||||
|
self.bus.enable_torque()
|
||||||
|
|
||||||
|
def get_observation(self) -> dict[str, Any]:
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
obs_dict = {OBS_IMAGES: {}}
|
||||||
|
|
||||||
|
# Read actuators position for arm and vel for base
|
||||||
|
start = time.perf_counter()
|
||||||
|
arm_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
||||||
|
base_vel = self.bus.sync_read("Present_Velocity", self.base_motors)
|
||||||
|
obs_dict[OBS_STATE] = {**arm_pos, **base_vel}
|
||||||
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
|
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
|
# Capture images from cameras
|
||||||
|
for cam_key, cam in self.cameras.items():
|
||||||
|
start = time.perf_counter()
|
||||||
|
obs_dict[OBS_IMAGES][cam_key] = cam.async_read()
|
||||||
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
|
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
|
return obs_dict
|
||||||
|
|
||||||
|
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
# Copied from S100 robot
|
||||||
|
"""Command lekiwi to move to a target joint configuration.
|
||||||
|
|
||||||
|
The relative action magnitude may be clipped depending on the configuration parameter
|
||||||
|
`max_relative_target`. In this case, the action sent differs from original action.
|
||||||
|
Thus, this function always returns the action actually sent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RobotDeviceNotConnectedError: if robot is not connected.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: the action sent to the motors, potentially clipped.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
arm_goal_pos = {k: v for k, v in action.items() if k in self.arm_motors}
|
||||||
|
base_goal_vel = {k: v for k, v in action.items() if k in self.base_motors}
|
||||||
|
|
||||||
|
# Cap goal position when too far away from present position.
|
||||||
|
# /!\ Slower fps expected due to reading from the follower.
|
||||||
|
if self.config.max_relative_target is not None:
|
||||||
|
present_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
||||||
|
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in arm_goal_pos.items()}
|
||||||
|
arm_safe_goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||||
|
arm_goal_pos = arm_safe_goal_pos
|
||||||
|
|
||||||
|
# TODO(Steven): Message fetching failed: Magnitude 34072 exceeds 32767 (max for sign_bit_index=15)
|
||||||
|
# TODO(Steven): IMO, this should be a check in the motor bus
|
||||||
|
|
||||||
|
# Send goal position to the actuators
|
||||||
|
self.bus.sync_write("Goal_Position", arm_goal_pos)
|
||||||
|
self.bus.sync_write("Goal_Velocity", base_goal_vel)
|
||||||
|
|
||||||
|
return {**arm_goal_pos, **base_goal_vel}
|
||||||
|
|
||||||
|
def stop_base(self):
|
||||||
|
self.bus.sync_write("Goal_Velocity", dict.fromkeys(self.base_motors, 0), num_retry=5)
|
||||||
|
logger.info("Base motors stopped")
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
self.stop_base()
|
||||||
|
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||||
|
for cam in self.cameras.values():
|
||||||
|
cam.disconnect()
|
||||||
|
|
||||||
|
logger.info(f"{self} disconnected.")
|
|
@ -0,0 +1,505 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||||
|
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError, InvalidActionError
|
||||||
|
|
||||||
|
from ..robot import Robot
|
||||||
|
from .config_lekiwi import LeKiwiClientConfig, RobotMode
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): This doesn't need to inherit from Robot
|
||||||
|
# But we do it for now to offer a familiar API
|
||||||
|
# TODO(Steven): This doesn't need to take care of the
|
||||||
|
# mapping from teleop to motor commands, but given that
|
||||||
|
# we already have a middle-man (this class) we add it here
|
||||||
|
# Other options include:
|
||||||
|
# 1. Adding it to the Telop implementation for lekiwi
|
||||||
|
# (meaning each robot will need a teleop imple) or
|
||||||
|
# 2. Adding it into the robot implementation
|
||||||
|
# (meaning the policy might be needed to be train
|
||||||
|
# over the teleop action space)
|
||||||
|
# TODO(Steven): Check if we can move everything to 32 instead
|
||||||
|
class LeKiwiClient(Robot):
|
||||||
|
config_class = LeKiwiClientConfig
|
||||||
|
name = "lekiwi_client"
|
||||||
|
|
||||||
|
def __init__(self, config: LeKiwiClientConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.id = config.id
|
||||||
|
self.robot_type = config.type
|
||||||
|
self.robot_mode = config.robot_mode
|
||||||
|
|
||||||
|
self.remote_ip = config.remote_ip
|
||||||
|
self.port_zmq_cmd = config.port_zmq_cmd
|
||||||
|
self.port_zmq_observations = config.port_zmq_observations
|
||||||
|
|
||||||
|
self.teleop_keys = config.teleop_keys
|
||||||
|
|
||||||
|
self.zmq_context = None
|
||||||
|
self.zmq_cmd_socket = None
|
||||||
|
self.zmq_observation_socket = None
|
||||||
|
|
||||||
|
self.last_frames = {}
|
||||||
|
self.last_present_speed = {"x_cmd": 0.0, "y_cmd": 0.0, "theta_cmd": 0.0}
|
||||||
|
|
||||||
|
self.last_remote_arm_state = {}
|
||||||
|
|
||||||
|
# Define three speed levels and a current index
|
||||||
|
self.speed_levels = [
|
||||||
|
{"xy": 0.1, "theta": 30}, # slow
|
||||||
|
{"xy": 0.2, "theta": 60}, # medium
|
||||||
|
{"xy": 0.3, "theta": 90}, # fast
|
||||||
|
]
|
||||||
|
self.speed_index = 0 # Start at slow
|
||||||
|
|
||||||
|
self._is_connected = False
|
||||||
|
self.logs = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_feature(self) -> dict:
|
||||||
|
# TODO(Steven): Get this from the data fetched? Motor names are unknown for the Daemon
|
||||||
|
# For now we assume its size/metadata is known
|
||||||
|
return {
|
||||||
|
"dtype": "float64",
|
||||||
|
"shape": (9,),
|
||||||
|
"names": {
|
||||||
|
"motors": [
|
||||||
|
"arm_shoulder_pan",
|
||||||
|
"arm_shoulder_lift",
|
||||||
|
"arm_elbow_flex",
|
||||||
|
"arm_wrist_flex",
|
||||||
|
"arm_wrist_roll",
|
||||||
|
"arm_gripper",
|
||||||
|
"base_left_wheel",
|
||||||
|
"base_right_wheel",
|
||||||
|
"base_back_wheel",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_feature(self) -> dict:
|
||||||
|
return self.state_feature
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_features(self) -> dict[str, dict]:
|
||||||
|
# TODO(Steven): Get this from the data fetched? Motor names are unknown for the Daemon
|
||||||
|
# For now we assume its size/metadata is known
|
||||||
|
# TODO(Steven): Check consistency of image sizes
|
||||||
|
cam_ft = {
|
||||||
|
"front": {
|
||||||
|
"shape": (480, 640, 3),
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": None,
|
||||||
|
},
|
||||||
|
"wrist": {
|
||||||
|
"shape": (480, 640, 3),
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cam_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
# TODO(Steven): Ideally we could check instead the status of the sockets
|
||||||
|
# I didn't find any API that allows us to do that easily
|
||||||
|
return self._is_connected
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
"""Establishes ZMQ sockets with the remote mobile robot"""
|
||||||
|
|
||||||
|
# TODO(Steven): Consider instead returning a bool + warn
|
||||||
|
if self._is_connected:
|
||||||
|
raise DeviceAlreadyConnectedError(
|
||||||
|
"LeKiwi Daemon is already connected. Do not run `robot.connect()` twice."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.zmq_context = zmq.Context()
|
||||||
|
self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH)
|
||||||
|
zmq_cmd_locator = f"tcp://{self.remote_ip}:{self.port_zmq_cmd}"
|
||||||
|
self.zmq_cmd_socket.connect(zmq_cmd_locator)
|
||||||
|
self.zmq_cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
||||||
|
|
||||||
|
self.zmq_observation_socket = self.zmq_context.socket(zmq.PULL)
|
||||||
|
zmq_observations_locator = f"tcp://{self.remote_ip}:{self.port_zmq_observations}"
|
||||||
|
self.zmq_observation_socket.connect(zmq_observations_locator)
|
||||||
|
self.zmq_observation_socket.setsockopt(zmq.CONFLATE, 1)
|
||||||
|
|
||||||
|
self._is_connected = True
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
logging.warning("LeKiwiClient has nothing to calibrate.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Consider moving these static functions out of the class
|
||||||
|
# Copied from robot_lekiwi MobileManipulator class* (before the refactor)
|
||||||
|
@staticmethod
|
||||||
|
def _degps_to_raw(degps: float) -> int:
|
||||||
|
steps_per_deg = 4096.0 / 360.0
|
||||||
|
speed_in_steps = degps * steps_per_deg
|
||||||
|
speed_int = int(round(speed_in_steps))
|
||||||
|
# Cap the value to fit within signed 16-bit range (-32768 to 32767)
|
||||||
|
if speed_int > 0x7FFF:
|
||||||
|
speed_int = 0x7FFF # 32767 -> maximum positive value
|
||||||
|
elif speed_int < -0x8000:
|
||||||
|
speed_int = -0x8000 # -32768 -> minimum negative value
|
||||||
|
return speed_int
|
||||||
|
|
||||||
|
# Copied from robot_lekiwi MobileManipulator class* (before the refactor)
|
||||||
|
@staticmethod
|
||||||
|
def _raw_to_degps(raw_speed: int) -> float:
|
||||||
|
steps_per_deg = 4096.0 / 360.0
|
||||||
|
magnitude = raw_speed
|
||||||
|
degps = magnitude / steps_per_deg
|
||||||
|
return degps
|
||||||
|
|
||||||
|
# Copied from robot_lekiwi MobileManipulator class* (before the refactor)
|
||||||
|
def _body_to_wheel_raw(
|
||||||
|
self,
|
||||||
|
x_cmd: float,
|
||||||
|
y_cmd: float,
|
||||||
|
theta_cmd: float,
|
||||||
|
wheel_radius: float = 0.05,
|
||||||
|
base_radius: float = 0.125,
|
||||||
|
max_raw: int = 3000,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Convert desired body-frame velocities into wheel raw commands.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
x_cmd : Linear velocity in x (m/s).
|
||||||
|
y_cmd : Linear velocity in y (m/s).
|
||||||
|
theta_cmd : Rotational velocity (deg/s).
|
||||||
|
wheel_radius: Radius of each wheel (meters).
|
||||||
|
base_radius : Distance from the center of rotation to each wheel (meters).
|
||||||
|
max_raw : Maximum allowed raw command (ticks) per wheel.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with wheel raw commands:
|
||||||
|
{"left_wheel": value, "back_wheel": value, "right_wheel": value}.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Internally, the method converts theta_cmd to rad/s for the kinematics.
|
||||||
|
- The raw command is computed from the wheels angular speed in deg/s
|
||||||
|
using _degps_to_raw(). If any command exceeds max_raw, all commands
|
||||||
|
are scaled down proportionally.
|
||||||
|
"""
|
||||||
|
# Convert rotational velocity from deg/s to rad/s.
|
||||||
|
theta_rad = theta_cmd * (np.pi / 180.0)
|
||||||
|
# Create the body velocity vector [x, y, theta_rad].
|
||||||
|
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
|
||||||
|
|
||||||
|
# Define the wheel mounting angles with a -90° offset.
|
||||||
|
angles = np.radians(np.array([240, 120, 0]) - 90)
|
||||||
|
# Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
|
||||||
|
# The third column (base_radius) accounts for the effect of rotation.
|
||||||
|
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||||
|
|
||||||
|
# Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s).
|
||||||
|
wheel_linear_speeds = m.dot(velocity_vector)
|
||||||
|
wheel_angular_speeds = wheel_linear_speeds / wheel_radius
|
||||||
|
|
||||||
|
# Convert wheel angular speeds from rad/s to deg/s.
|
||||||
|
wheel_degps = wheel_angular_speeds * (180.0 / np.pi)
|
||||||
|
|
||||||
|
# Scaling
|
||||||
|
steps_per_deg = 4096.0 / 360.0
|
||||||
|
raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
|
||||||
|
max_raw_computed = max(raw_floats)
|
||||||
|
if max_raw_computed > max_raw:
|
||||||
|
scale = max_raw / max_raw_computed
|
||||||
|
wheel_degps = wheel_degps * scale
|
||||||
|
|
||||||
|
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
||||||
|
wheel_raw = [LeKiwiClient._degps_to_raw(deg) for deg in wheel_degps]
|
||||||
|
|
||||||
|
# TODO(Steven): remove hard-coded names
|
||||||
|
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
||||||
|
|
||||||
|
# Copied from robot_lekiwi MobileManipulator class
|
||||||
|
def _wheel_raw_to_body(
|
||||||
|
self, wheel_raw: dict[str, Any], wheel_radius: float = 0.05, base_radius: float = 0.125
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert wheel raw command feedback back into body-frame velocities.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
wheel_raw : Vector with raw wheel commands ("left_wheel", "back_wheel", "right_wheel").
|
||||||
|
wheel_radius: Radius of each wheel (meters).
|
||||||
|
base_radius : Distance from the robot center to each wheel (meters).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple (x_cmd, y_cmd, theta_cmd) where:
|
||||||
|
x_cmd : Linear velocity in x (m/s).
|
||||||
|
y_cmd : Linear velocity in y (m/s).
|
||||||
|
theta_cmd : Rotational velocity in deg/s.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO(Steven): No check is done for dict keys
|
||||||
|
# Convert each raw command back to an angular speed in deg/s.
|
||||||
|
wheel_degps = np.array([LeKiwiClient._raw_to_degps(int(v)) for _, v in wheel_raw.items()])
|
||||||
|
# Convert from deg/s to rad/s.
|
||||||
|
wheel_radps = wheel_degps * (np.pi / 180.0)
|
||||||
|
# Compute each wheel’s linear speed (m/s) from its angular speed.
|
||||||
|
wheel_linear_speeds = wheel_radps * wheel_radius
|
||||||
|
|
||||||
|
# Define the wheel mounting angles with a -90° offset.
|
||||||
|
angles = np.radians(np.array([240, 120, 0]) - 90)
|
||||||
|
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||||
|
|
||||||
|
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
|
||||||
|
m_inv = np.linalg.inv(m)
|
||||||
|
velocity_vector = m_inv.dot(wheel_linear_speeds)
|
||||||
|
x_cmd, y_cmd, theta_rad = velocity_vector
|
||||||
|
theta_cmd = theta_rad * (180.0 / np.pi)
|
||||||
|
return {"x_cmd": x_cmd, "y_cmd": y_cmd, "theta_cmd": theta_cmd}
|
||||||
|
|
||||||
|
# TODO(Steven): This is flaky, for example, if we received a state but failed decoding the image, we will not update any value
|
||||||
|
# TODO(Steven): All this function needs to be refactored
|
||||||
|
# Copied from robot_lekiwi MobileManipulator class* (before the refactor)
|
||||||
|
def _get_data(self):
|
||||||
|
# Copied from robot_lekiwi.py
|
||||||
|
"""Polls the video socket for up to 15 ms. If data arrives, decode only
|
||||||
|
the *latest* message, returning frames, speed, and arm state. If
|
||||||
|
nothing arrives for any field, use the last known values."""
|
||||||
|
|
||||||
|
frames = {}
|
||||||
|
present_speed = {}
|
||||||
|
|
||||||
|
remote_arm_state_tensor = {}
|
||||||
|
|
||||||
|
# Poll up to 15 ms
|
||||||
|
poller = zmq.Poller()
|
||||||
|
poller.register(self.zmq_observation_socket, zmq.POLLIN)
|
||||||
|
socks = dict(poller.poll(15))
|
||||||
|
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
|
||||||
|
# No new data arrived → reuse ALL old data
|
||||||
|
# TODO(Steven): This might return empty variables at init
|
||||||
|
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||||
|
|
||||||
|
# Drain all messages, keep only the last
|
||||||
|
last_msg = None
|
||||||
|
# TODO(Steven): There's probably a way to do this without while True
|
||||||
|
# TODO(Steven): Even consider changing to PUB/SUB
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
obs_string = self.zmq_observation_socket.recv_string(zmq.NOBLOCK)
|
||||||
|
last_msg = obs_string
|
||||||
|
except zmq.Again:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not last_msg:
|
||||||
|
# No new message → also reuse old
|
||||||
|
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||||
|
|
||||||
|
# Decode only the final message
|
||||||
|
try:
|
||||||
|
observation = json.loads(last_msg)
|
||||||
|
|
||||||
|
state_observation = observation[OBS_STATE]
|
||||||
|
image_observation = observation[OBS_IMAGES]
|
||||||
|
|
||||||
|
# Convert images
|
||||||
|
for cam_name, image_b64 in image_observation.items():
|
||||||
|
if image_b64:
|
||||||
|
jpg_data = base64.b64decode(image_b64)
|
||||||
|
np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
|
||||||
|
frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||||
|
if frame_candidate is not None:
|
||||||
|
frames[cam_name] = frame_candidate
|
||||||
|
|
||||||
|
# TODO(Steven): Should we really ignore the arm state if the image is None?
|
||||||
|
# If remote_arm_state is None and frames is None there is no message then use the previous message
|
||||||
|
if state_observation is not None and frames is not None:
|
||||||
|
self.last_frames = frames
|
||||||
|
|
||||||
|
# TODO(Steven): hard-coded name of expected keys, not good
|
||||||
|
remote_arm_state_tensor = {k: v for k, v in state_observation.items() if k.startswith("arm")}
|
||||||
|
self.last_remote_arm_state = remote_arm_state_tensor
|
||||||
|
|
||||||
|
present_speed = {k: v for k, v in state_observation.items() if k.startswith("base")}
|
||||||
|
self.last_present_speed = present_speed
|
||||||
|
else:
|
||||||
|
frames = self.last_frames
|
||||||
|
remote_arm_state_tensor = self.last_remote_arm_state
|
||||||
|
present_speed = self.last_present_speed
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[DEBUG] Error decoding video message: {e}")
|
||||||
|
# If decode fails, fall back to old data
|
||||||
|
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||||
|
return frames, present_speed, remote_arm_state_tensor
|
||||||
|
|
||||||
|
# TODO(Steven): The returned space is different from the get_observation of LeKiwi
|
||||||
|
# This returns body-frames velocities instead of wheel pos/speeds
|
||||||
|
def get_observation(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Capture observations from the remote robot: current follower arm positions,
|
||||||
|
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||||
|
and a camera frame. Receives over ZMQ, translate to body-frame vel
|
||||||
|
"""
|
||||||
|
if not self._is_connected:
|
||||||
|
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
|
||||||
|
|
||||||
|
# TODO(Steven): remove hard-coded cam names & dims
|
||||||
|
# This is needed at init for when there's no comms
|
||||||
|
obs_dict = {
|
||||||
|
OBS_IMAGES: {"wrist": np.zeros(shape=(480, 640, 3)), "front": np.zeros(shape=(640, 480, 3))}
|
||||||
|
}
|
||||||
|
|
||||||
|
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
||||||
|
body_state = self._wheel_raw_to_body(present_speed)
|
||||||
|
# TODO(Steven): output is dict[str,Any] and we multiply by 1000.0. This should be more explicit and specify the expected type instead of Any
|
||||||
|
body_state_mm = {k: v * 1000.0 for k, v in body_state.items()} # Convert x,y to mm/s
|
||||||
|
|
||||||
|
obs_dict[OBS_STATE] = {**remote_arm_state_tensor, **body_state_mm}
|
||||||
|
|
||||||
|
# Loop over each configured camera
|
||||||
|
for cam_name, frame in frames.items():
|
||||||
|
if frame is None:
|
||||||
|
# TODO(Steven): Daemon doesn't know camera dimensions (hard-coded for now), consider at least getting them from state features
|
||||||
|
logging.warning("Frame is None")
|
||||||
|
frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||||
|
obs_dict[OBS_IMAGES][cam_name] = torch.from_numpy(frame)
|
||||||
|
|
||||||
|
return obs_dict
|
||||||
|
|
||||||
|
def _from_keyboard_to_wheel_action(self, pressed_keys: np.ndarray):
|
||||||
|
# Speed control
|
||||||
|
if self.teleop_keys["speed_up"] in pressed_keys:
|
||||||
|
self.speed_index = min(self.speed_index + 1, 2)
|
||||||
|
if self.teleop_keys["speed_down"] in pressed_keys:
|
||||||
|
self.speed_index = max(self.speed_index - 1, 0)
|
||||||
|
speed_setting = self.speed_levels[self.speed_index]
|
||||||
|
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
|
||||||
|
theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90
|
||||||
|
|
||||||
|
x_cmd = 0.0 # m/s forward/backward
|
||||||
|
y_cmd = 0.0 # m/s lateral
|
||||||
|
theta_cmd = 0.0 # deg/s rotation
|
||||||
|
|
||||||
|
if self.teleop_keys["forward"] in pressed_keys:
|
||||||
|
x_cmd += xy_speed
|
||||||
|
if self.teleop_keys["backward"] in pressed_keys:
|
||||||
|
x_cmd -= xy_speed
|
||||||
|
if self.teleop_keys["left"] in pressed_keys:
|
||||||
|
y_cmd += xy_speed
|
||||||
|
if self.teleop_keys["right"] in pressed_keys:
|
||||||
|
y_cmd -= xy_speed
|
||||||
|
if self.teleop_keys["rotate_left"] in pressed_keys:
|
||||||
|
theta_cmd += theta_speed
|
||||||
|
if self.teleop_keys["rotate_right"] in pressed_keys:
|
||||||
|
theta_cmd -= theta_speed
|
||||||
|
return self._body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
||||||
|
|
||||||
|
def configure(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO(Steven): This assumes this call is always called with a keyboard as a teleop device. It breaks if we teleop with other device
|
||||||
|
# TODO(Steven): Doing this mapping in here adds latecy between send_action and movement from the user perspective.
|
||||||
|
# t0: get teleop_cmd
|
||||||
|
# t1: send_action(teleop_cmd)
|
||||||
|
# t2: mapping teleop_cmd -> motor_cmd
|
||||||
|
# t3: execute motor_md
|
||||||
|
# This mapping for other robots/teleop devices might be slower. Doing this in the teleop will make this explicit
|
||||||
|
# t0': get teleop_cmd
|
||||||
|
# t1': mapping teleop_cmd -> motor_cmd
|
||||||
|
# t2': send_action(motor_cmd)
|
||||||
|
# t3': execute motor_cmd
|
||||||
|
# t3'-t2' << t3-t1
|
||||||
|
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action (np.ndarray): array containing the goal positions for the motors.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RobotDeviceNotConnectedError: if robot is not connected.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: the action sent to the motors, potentially clipped.
|
||||||
|
"""
|
||||||
|
if not self._is_connected:
|
||||||
|
raise DeviceNotConnectedError(
|
||||||
|
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.robot_mode is RobotMode.AUTO:
|
||||||
|
# TODO(Steven): Not yet implemented. The policy outputs might need a different conversion
|
||||||
|
raise InvalidActionError("Policy output as action input is not yet well defined")
|
||||||
|
|
||||||
|
goal_pos = {}
|
||||||
|
# TODO(Steven): This assumes teleop mode is always used with keyboard. Tomorrow we could teleop with another device ... ?
|
||||||
|
if self.robot_mode is RobotMode.TELEOP:
|
||||||
|
motors_name = self.state_feature.get("names").get("motors")
|
||||||
|
|
||||||
|
common_keys = [
|
||||||
|
key for key in action if key in (motor.replace("arm_", "") for motor in motors_name)
|
||||||
|
]
|
||||||
|
|
||||||
|
# TODO(Steven): I don't like this
|
||||||
|
if len(common_keys) < 6:
|
||||||
|
logging.error("Action should include at least the states of the leader arm")
|
||||||
|
raise InvalidActionError
|
||||||
|
|
||||||
|
arm_actions = {"arm_" + arm_motor: action[arm_motor] for arm_motor in common_keys}
|
||||||
|
goal_pos = arm_actions
|
||||||
|
|
||||||
|
if len(action) > 6:
|
||||||
|
keyboard_keys = np.array(list(set(action.keys()) - set(common_keys)))
|
||||||
|
wheel_actions = {
|
||||||
|
"base_" + k: v for k, v in self._from_keyboard_to_wheel_action(keyboard_keys).items()
|
||||||
|
}
|
||||||
|
goal_pos = {**arm_actions, **wheel_actions}
|
||||||
|
|
||||||
|
self.zmq_cmd_socket.send_string(json.dumps(goal_pos)) # action is in motor space
|
||||||
|
|
||||||
|
return goal_pos
|
||||||
|
|
||||||
|
def print_logs(self):
|
||||||
|
# TODO(Steven): Refactor logger
|
||||||
|
pass
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
"""Cleans ZMQ comms"""
|
||||||
|
|
||||||
|
if not self._is_connected:
|
||||||
|
raise DeviceNotConnectedError(
|
||||||
|
"LeKiwi is not connected. You need to run `robot.connect()` before disconnecting."
|
||||||
|
)
|
||||||
|
self.zmq_observation_socket.close()
|
||||||
|
self.zmq_cmd_socket.close()
|
||||||
|
self.zmq_context.term()
|
||||||
|
self._is_connected = False
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if getattr(self, "is_connected", False):
|
||||||
|
self.disconnect()
|
|
@ -0,0 +1,117 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from lerobot.common.constants import OBS_IMAGES
|
||||||
|
|
||||||
|
from .config_lekiwi import LeKiwiConfig
|
||||||
|
from .lekiwi import LeKiwi
|
||||||
|
|
||||||
|
|
||||||
|
class HostAgent:
|
||||||
|
def __init__(self, port_zmq_cmd, port_zmq_observations):
|
||||||
|
self.zmq_context = zmq.Context()
|
||||||
|
self.zmq_cmd_socket = self.zmq_context.socket(zmq.PULL)
|
||||||
|
self.zmq_cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
||||||
|
self.zmq_cmd_socket.bind(f"tcp://*:{port_zmq_cmd}")
|
||||||
|
|
||||||
|
self.zmq_observation_socket = self.zmq_context.socket(zmq.PUSH)
|
||||||
|
self.zmq_observation_socket.setsockopt(zmq.CONFLATE, 1)
|
||||||
|
self.zmq_observation_socket.bind(f"tcp://*:{port_zmq_observations}")
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
self.zmq_observation_socket.close()
|
||||||
|
self.zmq_cmd_socket.close()
|
||||||
|
self.zmq_context.term()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
logging.info("Configuring LeKiwi")
|
||||||
|
robot_config = LeKiwiConfig()
|
||||||
|
robot = LeKiwi(robot_config)
|
||||||
|
|
||||||
|
logging.info("Connecting LeKiwi")
|
||||||
|
robot.connect()
|
||||||
|
|
||||||
|
logging.info("Starting HostAgent")
|
||||||
|
remote_agent = HostAgent(robot_config.port_zmq_cmd, robot_config.port_zmq_observations)
|
||||||
|
|
||||||
|
last_cmd_time = time.time()
|
||||||
|
logging.info("Waiting for commands...")
|
||||||
|
try:
|
||||||
|
# Business logic
|
||||||
|
start = time.perf_counter()
|
||||||
|
duration = 0
|
||||||
|
while duration < 100:
|
||||||
|
loop_start_time = time.time()
|
||||||
|
try:
|
||||||
|
msg = remote_agent.zmq_cmd_socket.recv_string(zmq.NOBLOCK)
|
||||||
|
data = dict(json.loads(msg))
|
||||||
|
_action_sent = robot.send_action(data)
|
||||||
|
last_cmd_time = time.time()
|
||||||
|
except zmq.Again:
|
||||||
|
logging.warning("No command available")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Message fetching failed: %s", e)
|
||||||
|
|
||||||
|
# TODO(Steven): Check this value
|
||||||
|
# Watchdog: stop the robot if no command is received for over 0.5 seconds.
|
||||||
|
now = time.time()
|
||||||
|
if now - last_cmd_time > 0.5:
|
||||||
|
robot.stop_base()
|
||||||
|
|
||||||
|
last_observation = robot.get_observation()
|
||||||
|
|
||||||
|
# Encode ndarrays to base64 strings
|
||||||
|
for cam_key, _ in robot.cameras.items():
|
||||||
|
ret, buffer = cv2.imencode(
|
||||||
|
".jpg", last_observation[OBS_IMAGES][cam_key], [int(cv2.IMWRITE_JPEG_QUALITY), 90]
|
||||||
|
)
|
||||||
|
if ret:
|
||||||
|
last_observation[OBS_IMAGES][cam_key] = base64.b64encode(buffer).decode("utf-8")
|
||||||
|
else:
|
||||||
|
last_observation[OBS_IMAGES][cam_key] = ""
|
||||||
|
|
||||||
|
# Send the observation to the remote agent
|
||||||
|
remote_agent.zmq_observation_socket.send_string(json.dumps(last_observation))
|
||||||
|
|
||||||
|
# Ensure a short sleep to avoid overloading the CPU.
|
||||||
|
elapsed = time.time() - loop_start_time
|
||||||
|
|
||||||
|
# TODO(Steven): Check this value
|
||||||
|
time.sleep(
|
||||||
|
max(0.033 - elapsed, 0)
|
||||||
|
) # If robot jitters increase the sleep and monitor cpu load with `top` in cmd
|
||||||
|
duration = time.perf_counter() - start
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Shutting down LeKiwi server.")
|
||||||
|
finally:
|
||||||
|
robot.disconnect()
|
||||||
|
remote_agent.disconnect()
|
||||||
|
|
||||||
|
logging.info("Finished LeKiwi cleanly")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -1,224 +0,0 @@
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import zmq
|
|
||||||
|
|
||||||
from lerobot.common.robots.mobile_manipulator import LeKiwi
|
|
||||||
|
|
||||||
|
|
||||||
def setup_zmq_sockets(config):
|
|
||||||
context = zmq.Context()
|
|
||||||
cmd_socket = context.socket(zmq.PULL)
|
|
||||||
cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
|
||||||
cmd_socket.bind(f"tcp://*:{config.port}")
|
|
||||||
|
|
||||||
video_socket = context.socket(zmq.PUSH)
|
|
||||||
video_socket.setsockopt(zmq.CONFLATE, 1)
|
|
||||||
video_socket.bind(f"tcp://*:{config.video_port}")
|
|
||||||
|
|
||||||
return context, cmd_socket, video_socket
|
|
||||||
|
|
||||||
|
|
||||||
def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
|
|
||||||
while not stop_event.is_set():
|
|
||||||
local_dict = {}
|
|
||||||
for name, cam in cameras.items():
|
|
||||||
frame = cam.async_read()
|
|
||||||
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
|
|
||||||
if ret:
|
|
||||||
local_dict[name] = base64.b64encode(buffer).decode("utf-8")
|
|
||||||
else:
|
|
||||||
local_dict[name] = ""
|
|
||||||
with images_lock:
|
|
||||||
latest_images_dict.update(local_dict)
|
|
||||||
time.sleep(0.01)
|
|
||||||
|
|
||||||
|
|
||||||
def calibrate_follower_arm(motors_bus, calib_dir_str):
|
|
||||||
"""
|
|
||||||
Calibrates the follower arm. Attempts to load an existing calibration file;
|
|
||||||
if not found, runs manual calibration and saves the result.
|
|
||||||
"""
|
|
||||||
calib_dir = Path(calib_dir_str)
|
|
||||||
calib_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
calib_file = calib_dir / "main_follower.json"
|
|
||||||
try:
|
|
||||||
from lerobot.common.motors.feetech.feetech_calibration import run_full_arm_calibration
|
|
||||||
except ImportError:
|
|
||||||
print("[WARNING] Calibration function not available. Skipping calibration.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if calib_file.exists():
|
|
||||||
with open(calib_file) as f:
|
|
||||||
calibration = json.load(f)
|
|
||||||
print(f"[INFO] Loaded calibration from {calib_file}")
|
|
||||||
else:
|
|
||||||
print("[INFO] Calibration file not found. Running manual calibration...")
|
|
||||||
calibration = run_full_arm_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
|
||||||
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
|
||||||
with open(calib_file, "w") as f:
|
|
||||||
json.dump(calibration, f)
|
|
||||||
try:
|
|
||||||
motors_bus.set_calibration(calibration)
|
|
||||||
print("[INFO] Applied calibration for follower arm.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[WARNING] Could not apply calibration: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def run_lekiwi(robot_config):
|
|
||||||
"""
|
|
||||||
Runs the LeKiwi robot:
|
|
||||||
- Sets up cameras and connects them.
|
|
||||||
- Initializes the follower arm motors.
|
|
||||||
- Calibrates the follower arm if necessary.
|
|
||||||
- Creates ZeroMQ sockets for receiving commands and streaming observations.
|
|
||||||
- Processes incoming commands (arm and wheel commands) and sends back sensor and camera data.
|
|
||||||
"""
|
|
||||||
# Import helper functions and classes
|
|
||||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
|
||||||
from lerobot.common.motors.feetech.feetech import FeetechMotorsBus, TorqueMode
|
|
||||||
|
|
||||||
# Initialize cameras from the robot configuration.
|
|
||||||
cameras = make_cameras_from_configs(robot_config.cameras)
|
|
||||||
for cam in cameras.values():
|
|
||||||
cam.connect()
|
|
||||||
|
|
||||||
# Initialize the motors bus using the follower arm configuration.
|
|
||||||
motor_config = robot_config.follower_arms.get("main")
|
|
||||||
if motor_config is None:
|
|
||||||
print("[ERROR] Follower arm 'main' configuration not found.")
|
|
||||||
return
|
|
||||||
motors_bus = FeetechMotorsBus(motor_config)
|
|
||||||
motors_bus.connect()
|
|
||||||
|
|
||||||
# Calibrate the follower arm.
|
|
||||||
calibrate_follower_arm(motors_bus, robot_config.calibration_dir)
|
|
||||||
|
|
||||||
# Create the LeKiwi robot instance.
|
|
||||||
robot = LeKiwi(motors_bus)
|
|
||||||
|
|
||||||
# Define the expected arm motor IDs.
|
|
||||||
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
|
|
||||||
|
|
||||||
# Disable torque for each arm motor.
|
|
||||||
for motor in arm_motor_ids:
|
|
||||||
motors_bus.write("Torque_Enable", TorqueMode.DISABLED.value, motor)
|
|
||||||
|
|
||||||
# Set up ZeroMQ sockets.
|
|
||||||
context, cmd_socket, video_socket = setup_zmq_sockets(robot_config)
|
|
||||||
|
|
||||||
# Start the camera capture thread.
|
|
||||||
latest_images_dict = {}
|
|
||||||
images_lock = threading.Lock()
|
|
||||||
stop_event = threading.Event()
|
|
||||||
cam_thread = threading.Thread(
|
|
||||||
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
|
|
||||||
)
|
|
||||||
cam_thread.start()
|
|
||||||
|
|
||||||
last_cmd_time = time.time()
|
|
||||||
print("LeKiwi robot server started. Waiting for commands...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
loop_start_time = time.time()
|
|
||||||
|
|
||||||
# Process incoming commands (non-blocking).
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
msg = cmd_socket.recv_string(zmq.NOBLOCK)
|
|
||||||
except zmq.Again:
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
data = json.loads(msg)
|
|
||||||
# Process arm position commands.
|
|
||||||
if "arm_positions" in data:
|
|
||||||
arm_positions = data["arm_positions"]
|
|
||||||
if not isinstance(arm_positions, list):
|
|
||||||
print(f"[ERROR] Invalid arm_positions: {arm_positions}")
|
|
||||||
elif len(arm_positions) < len(arm_motor_ids):
|
|
||||||
print(
|
|
||||||
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
|
|
||||||
motors_bus.write("Goal_Position", pos, motor)
|
|
||||||
# Process wheel (base) commands.
|
|
||||||
if "raw_velocity" in data:
|
|
||||||
raw_command = data["raw_velocity"]
|
|
||||||
# Expect keys: "left_wheel", "back_wheel", "right_wheel".
|
|
||||||
command_speeds = [
|
|
||||||
int(raw_command.get("left_wheel", 0)),
|
|
||||||
int(raw_command.get("back_wheel", 0)),
|
|
||||||
int(raw_command.get("right_wheel", 0)),
|
|
||||||
]
|
|
||||||
robot.set_velocity(command_speeds)
|
|
||||||
last_cmd_time = time.time()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[ERROR] Parsing message failed: {e}")
|
|
||||||
|
|
||||||
# Watchdog: stop the robot if no command is received for over 0.5 seconds.
|
|
||||||
now = time.time()
|
|
||||||
if now - last_cmd_time > 0.5:
|
|
||||||
robot.stop()
|
|
||||||
last_cmd_time = now
|
|
||||||
|
|
||||||
# Read current wheel speeds from the robot.
|
|
||||||
current_velocity = robot.read_velocity()
|
|
||||||
|
|
||||||
# Read the follower arm state from the motors bus.
|
|
||||||
follower_arm_state = []
|
|
||||||
for motor in arm_motor_ids:
|
|
||||||
try:
|
|
||||||
pos = motors_bus.read("Present_Position", motor)
|
|
||||||
# Convert the position to a float (or use as is if already numeric).
|
|
||||||
follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[ERROR] Reading motor {motor} failed: {e}")
|
|
||||||
|
|
||||||
# Get the latest camera images.
|
|
||||||
with images_lock:
|
|
||||||
images_dict_copy = dict(latest_images_dict)
|
|
||||||
|
|
||||||
# Build the observation dictionary.
|
|
||||||
observation = {
|
|
||||||
"images": images_dict_copy,
|
|
||||||
"present_speed": current_velocity,
|
|
||||||
"follower_arm_state": follower_arm_state,
|
|
||||||
}
|
|
||||||
# Send the observation over the video socket.
|
|
||||||
video_socket.send_string(json.dumps(observation))
|
|
||||||
|
|
||||||
# Ensure a short sleep to avoid overloading the CPU.
|
|
||||||
elapsed = time.time() - loop_start_time
|
|
||||||
time.sleep(
|
|
||||||
max(0.033 - elapsed, 0)
|
|
||||||
) # If robot jitters increase the sleep and monitor cpu load with `top` in cmd
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("Shutting down LeKiwi server.")
|
|
||||||
finally:
|
|
||||||
stop_event.set()
|
|
||||||
cam_thread.join()
|
|
||||||
robot.stop()
|
|
||||||
motors_bus.disconnect()
|
|
||||||
cmd_socket.close()
|
|
||||||
video_socket.close()
|
|
||||||
context.term()
|
|
|
@ -1,692 +0,0 @@
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import zmq
|
|
||||||
|
|
||||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
|
||||||
from lerobot.common.errors import DeviceNotConnectedError
|
|
||||||
from lerobot.common.motors.feetech.feetech import TorqueMode
|
|
||||||
from lerobot.common.motors.feetech.feetech_calibration import run_full_arm_calibration
|
|
||||||
from lerobot.common.motors.motors_bus import MotorsBus
|
|
||||||
from lerobot.common.motors.utils import make_motors_buses_from_configs
|
|
||||||
from lerobot.common.robots.lekiwi.configuration_lekiwi import LeKiwiRobotConfig
|
|
||||||
from lerobot.common.robots.utils import get_arm_id
|
|
||||||
|
|
||||||
PYNPUT_AVAILABLE = True
|
|
||||||
try:
|
|
||||||
# Only import if there's a valid X server or if we're not on a Pi
|
|
||||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
|
||||||
print("No DISPLAY set. Skipping pynput import.")
|
|
||||||
raise ImportError("pynput blocked intentionally due to no display.")
|
|
||||||
|
|
||||||
from pynput import keyboard
|
|
||||||
except ImportError:
|
|
||||||
keyboard = None
|
|
||||||
PYNPUT_AVAILABLE = False
|
|
||||||
except Exception as e:
|
|
||||||
keyboard = None
|
|
||||||
PYNPUT_AVAILABLE = False
|
|
||||||
print(f"Could not import pynput: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class MobileManipulator:
|
|
||||||
"""
|
|
||||||
MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot.
|
|
||||||
The robot includes a three omniwheel mobile base and a remote follower arm.
|
|
||||||
The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
|
|
||||||
forwarded to the remote follower arm (after applying a safety clamp).
|
|
||||||
In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: LeKiwiRobotConfig):
|
|
||||||
"""
|
|
||||||
Expected keys in config:
|
|
||||||
- ip, port, video_port for the remote connection.
|
|
||||||
- calibration_dir, leader_arms, follower_arms, max_relative_target, etc.
|
|
||||||
"""
|
|
||||||
self.robot_type = config.type
|
|
||||||
self.config = config
|
|
||||||
self.remote_ip = config.ip
|
|
||||||
self.remote_port = config.port
|
|
||||||
self.remote_port_video = config.video_port
|
|
||||||
self.calibration_dir = Path(self.config.calibration_dir)
|
|
||||||
self.logs = {}
|
|
||||||
|
|
||||||
self.teleop_keys = self.config.teleop_keys
|
|
||||||
|
|
||||||
# For teleoperation, the leader arm (local) is used to record the desired arm pose.
|
|
||||||
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
|
|
||||||
|
|
||||||
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
|
|
||||||
|
|
||||||
self.cameras = make_cameras_from_configs(self.config.cameras)
|
|
||||||
|
|
||||||
self.is_connected = False
|
|
||||||
|
|
||||||
self.last_frames = {}
|
|
||||||
self.last_present_speed = {}
|
|
||||||
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Define three speed levels and a current index
|
|
||||||
self.speed_levels = [
|
|
||||||
{"xy": 0.1, "theta": 30}, # slow
|
|
||||||
{"xy": 0.2, "theta": 60}, # medium
|
|
||||||
{"xy": 0.3, "theta": 90}, # fast
|
|
||||||
]
|
|
||||||
self.speed_index = 0 # Start at slow
|
|
||||||
|
|
||||||
# ZeroMQ context and sockets.
|
|
||||||
self.context = None
|
|
||||||
self.cmd_socket = None
|
|
||||||
self.video_socket = None
|
|
||||||
|
|
||||||
# Keyboard state for base teleoperation.
|
|
||||||
self.running = True
|
|
||||||
self.pressed_keys = {
|
|
||||||
"forward": False,
|
|
||||||
"backward": False,
|
|
||||||
"left": False,
|
|
||||||
"right": False,
|
|
||||||
"rotate_left": False,
|
|
||||||
"rotate_right": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
if PYNPUT_AVAILABLE:
|
|
||||||
print("pynput is available - enabling local keyboard listener.")
|
|
||||||
self.listener = keyboard.Listener(
|
|
||||||
on_press=self.on_press,
|
|
||||||
on_release=self.on_release,
|
|
||||||
)
|
|
||||||
self.listener.start()
|
|
||||||
else:
|
|
||||||
print("pynput not available - skipping local keyboard listener.")
|
|
||||||
self.listener = None
|
|
||||||
|
|
||||||
def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
|
|
||||||
return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def camera_features(self) -> dict:
|
|
||||||
cam_ft = {}
|
|
||||||
for cam_key, cam in self.cameras.items():
|
|
||||||
key = f"observation.images.{cam_key}"
|
|
||||||
cam_ft[key] = {
|
|
||||||
"shape": (cam.height, cam.width, cam.channels),
|
|
||||||
"names": ["height", "width", "channels"],
|
|
||||||
"info": None,
|
|
||||||
}
|
|
||||||
return cam_ft
|
|
||||||
|
|
||||||
@property
|
|
||||||
def motor_features(self) -> dict:
|
|
||||||
follower_arm_names = [
|
|
||||||
"shoulder_pan",
|
|
||||||
"shoulder_lift",
|
|
||||||
"elbow_flex",
|
|
||||||
"wrist_flex",
|
|
||||||
"wrist_roll",
|
|
||||||
"gripper",
|
|
||||||
]
|
|
||||||
observations = ["x_mm", "y_mm", "theta"]
|
|
||||||
combined_names = follower_arm_names + observations
|
|
||||||
return {
|
|
||||||
"action": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (len(combined_names),),
|
|
||||||
"names": combined_names,
|
|
||||||
},
|
|
||||||
"observation.state": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (len(combined_names),),
|
|
||||||
"names": combined_names,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def features(self):
|
|
||||||
return {**self.motor_features, **self.camera_features}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def has_camera(self):
|
|
||||||
return len(self.cameras) > 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_cameras(self):
|
|
||||||
return len(self.cameras)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available_arms(self):
|
|
||||||
available = []
|
|
||||||
for name in self.leader_arms:
|
|
||||||
available.append(get_arm_id(name, "leader"))
|
|
||||||
for name in self.follower_arms:
|
|
||||||
available.append(get_arm_id(name, "follower"))
|
|
||||||
return available
|
|
||||||
|
|
||||||
def on_press(self, key):
|
|
||||||
try:
|
|
||||||
# Movement
|
|
||||||
if key.char == self.teleop_keys["forward"]:
|
|
||||||
self.pressed_keys["forward"] = True
|
|
||||||
elif key.char == self.teleop_keys["backward"]:
|
|
||||||
self.pressed_keys["backward"] = True
|
|
||||||
elif key.char == self.teleop_keys["left"]:
|
|
||||||
self.pressed_keys["left"] = True
|
|
||||||
elif key.char == self.teleop_keys["right"]:
|
|
||||||
self.pressed_keys["right"] = True
|
|
||||||
elif key.char == self.teleop_keys["rotate_left"]:
|
|
||||||
self.pressed_keys["rotate_left"] = True
|
|
||||||
elif key.char == self.teleop_keys["rotate_right"]:
|
|
||||||
self.pressed_keys["rotate_right"] = True
|
|
||||||
|
|
||||||
# Quit teleoperation
|
|
||||||
elif key.char == self.teleop_keys["quit"]:
|
|
||||||
self.running = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Speed control
|
|
||||||
elif key.char == self.teleop_keys["speed_up"]:
|
|
||||||
self.speed_index = min(self.speed_index + 1, 2)
|
|
||||||
print(f"Speed index increased to {self.speed_index}")
|
|
||||||
elif key.char == self.teleop_keys["speed_down"]:
|
|
||||||
self.speed_index = max(self.speed_index - 1, 0)
|
|
||||||
print(f"Speed index decreased to {self.speed_index}")
|
|
||||||
|
|
||||||
except AttributeError:
|
|
||||||
# e.g., if key is special like Key.esc
|
|
||||||
if key == keyboard.Key.esc:
|
|
||||||
self.running = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
def on_release(self, key):
|
|
||||||
try:
|
|
||||||
if hasattr(key, "char"):
|
|
||||||
if key.char == self.teleop_keys["forward"]:
|
|
||||||
self.pressed_keys["forward"] = False
|
|
||||||
elif key.char == self.teleop_keys["backward"]:
|
|
||||||
self.pressed_keys["backward"] = False
|
|
||||||
elif key.char == self.teleop_keys["left"]:
|
|
||||||
self.pressed_keys["left"] = False
|
|
||||||
elif key.char == self.teleop_keys["right"]:
|
|
||||||
self.pressed_keys["right"] = False
|
|
||||||
elif key.char == self.teleop_keys["rotate_left"]:
|
|
||||||
self.pressed_keys["rotate_left"] = False
|
|
||||||
elif key.char == self.teleop_keys["rotate_right"]:
|
|
||||||
self.pressed_keys["rotate_right"] = False
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
if not self.leader_arms:
|
|
||||||
raise ValueError("MobileManipulator has no leader arm to connect.")
|
|
||||||
for name in self.leader_arms:
|
|
||||||
print(f"Connecting {name} leader arm.")
|
|
||||||
self.calibrate_leader()
|
|
||||||
|
|
||||||
# Set up ZeroMQ sockets to communicate with the remote mobile robot.
|
|
||||||
self.context = zmq.Context()
|
|
||||||
self.cmd_socket = self.context.socket(zmq.PUSH)
|
|
||||||
connection_string = f"tcp://{self.remote_ip}:{self.remote_port}"
|
|
||||||
self.cmd_socket.connect(connection_string)
|
|
||||||
self.cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
|
||||||
self.video_socket = self.context.socket(zmq.PULL)
|
|
||||||
video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}"
|
|
||||||
self.video_socket.connect(video_connection)
|
|
||||||
self.video_socket.setsockopt(zmq.CONFLATE, 1)
|
|
||||||
print(
|
|
||||||
f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}."
|
|
||||||
)
|
|
||||||
self.is_connected = True
|
|
||||||
|
|
||||||
def load_or_run_calibration_(self, name, arm, arm_type):
|
|
||||||
arm_id = get_arm_id(name, arm_type)
|
|
||||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
|
||||||
|
|
||||||
if arm_calib_path.exists():
|
|
||||||
with open(arm_calib_path) as f:
|
|
||||||
calibration = json.load(f)
|
|
||||||
else:
|
|
||||||
print(f"Missing calibration file '{arm_calib_path}'")
|
|
||||||
calibration = run_full_arm_calibration(arm, self.robot_type, name, arm_type)
|
|
||||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
|
||||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(arm_calib_path, "w") as f:
|
|
||||||
json.dump(calibration, f)
|
|
||||||
|
|
||||||
return calibration
|
|
||||||
|
|
||||||
def calibrate_leader(self):
|
|
||||||
for name, arm in self.leader_arms.items():
|
|
||||||
# Connect the bus
|
|
||||||
arm.connect()
|
|
||||||
|
|
||||||
# Disable torque on all motors
|
|
||||||
for motor_id in arm.motors:
|
|
||||||
arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id)
|
|
||||||
|
|
||||||
# Now run calibration
|
|
||||||
calibration = self.load_or_run_calibration_(name, arm, "leader")
|
|
||||||
arm.set_calibration(calibration)
|
|
||||||
|
|
||||||
def calibrate_follower(self):
|
|
||||||
for name, bus in self.follower_arms.items():
|
|
||||||
bus.connect()
|
|
||||||
|
|
||||||
# Disable torque on all motors
|
|
||||||
for motor_id in bus.motors:
|
|
||||||
bus.write("Torque_Enable", 0, motor_id)
|
|
||||||
|
|
||||||
# Then filter out wheels
|
|
||||||
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
|
|
||||||
if not arm_only_dict:
|
|
||||||
continue
|
|
||||||
|
|
||||||
original_motors = bus.motors
|
|
||||||
bus.motors = arm_only_dict
|
|
||||||
|
|
||||||
calibration = self.load_or_run_calibration_(name, bus, "follower")
|
|
||||||
bus.set_calibration(calibration)
|
|
||||||
|
|
||||||
bus.motors = original_motors
|
|
||||||
|
|
||||||
def _get_data(self):
|
|
||||||
"""
|
|
||||||
Polls the video socket for up to 15 ms. If data arrives, decode only
|
|
||||||
the *latest* message, returning frames, speed, and arm state. If
|
|
||||||
nothing arrives for any field, use the last known values.
|
|
||||||
"""
|
|
||||||
frames = {}
|
|
||||||
present_speed = {}
|
|
||||||
remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Poll up to 15 ms
|
|
||||||
poller = zmq.Poller()
|
|
||||||
poller.register(self.video_socket, zmq.POLLIN)
|
|
||||||
socks = dict(poller.poll(15))
|
|
||||||
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
|
|
||||||
# No new data arrived → reuse ALL old data
|
|
||||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
|
||||||
|
|
||||||
# Drain all messages, keep only the last
|
|
||||||
last_msg = None
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
obs_string = self.video_socket.recv_string(zmq.NOBLOCK)
|
|
||||||
last_msg = obs_string
|
|
||||||
except zmq.Again:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not last_msg:
|
|
||||||
# No new message → also reuse old
|
|
||||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
|
||||||
|
|
||||||
# Decode only the final message
|
|
||||||
try:
|
|
||||||
observation = json.loads(last_msg)
|
|
||||||
|
|
||||||
images_dict = observation.get("images", {})
|
|
||||||
new_speed = observation.get("present_speed", {})
|
|
||||||
new_arm_state = observation.get("follower_arm_state", None)
|
|
||||||
|
|
||||||
# Convert images
|
|
||||||
for cam_name, image_b64 in images_dict.items():
|
|
||||||
if image_b64:
|
|
||||||
jpg_data = base64.b64decode(image_b64)
|
|
||||||
np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
|
|
||||||
frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
|
||||||
if frame_candidate is not None:
|
|
||||||
frames[cam_name] = frame_candidate
|
|
||||||
|
|
||||||
# If remote_arm_state is None and frames is None there is no message then use the previous message
|
|
||||||
if new_arm_state is not None and frames is not None:
|
|
||||||
self.last_frames = frames
|
|
||||||
|
|
||||||
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
|
|
||||||
self.last_remote_arm_state = remote_arm_state_tensor
|
|
||||||
|
|
||||||
present_speed = new_speed
|
|
||||||
self.last_present_speed = new_speed
|
|
||||||
else:
|
|
||||||
frames = self.last_frames
|
|
||||||
|
|
||||||
remote_arm_state_tensor = self.last_remote_arm_state
|
|
||||||
|
|
||||||
present_speed = self.last_present_speed
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[DEBUG] Error decoding video message: {e}")
|
|
||||||
# If decode fails, fall back to old data
|
|
||||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
|
||||||
|
|
||||||
return frames, present_speed, remote_arm_state_tensor
|
|
||||||
|
|
||||||
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
|
|
||||||
state_tensor = torch.zeros(3, dtype=torch.int32)
|
|
||||||
if present_speed:
|
|
||||||
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
|
|
||||||
if "1" in decoded:
|
|
||||||
state_tensor[0] = decoded["1"]
|
|
||||||
if "2" in decoded:
|
|
||||||
state_tensor[1] = decoded["2"]
|
|
||||||
if "3" in decoded:
|
|
||||||
state_tensor[2] = decoded["3"]
|
|
||||||
return state_tensor
|
|
||||||
|
|
||||||
def teleop_step(
|
|
||||||
self, record_data: bool = False
|
|
||||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
|
|
||||||
|
|
||||||
speed_setting = self.speed_levels[self.speed_index]
|
|
||||||
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
|
|
||||||
theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90
|
|
||||||
|
|
||||||
# Prepare to assign the position of the leader to the follower
|
|
||||||
arm_positions = []
|
|
||||||
for name in self.leader_arms:
|
|
||||||
pos = self.leader_arms[name].read("Present_Position")
|
|
||||||
pos_tensor = torch.from_numpy(pos).float()
|
|
||||||
# Instead of pos_tensor.item(), use tolist() to convert the entire tensor to a list
|
|
||||||
arm_positions.extend(pos_tensor.tolist())
|
|
||||||
|
|
||||||
# (The rest of your code for generating wheel commands remains unchanged)
|
|
||||||
x_cmd = 0.0 # m/s forward/backward
|
|
||||||
y_cmd = 0.0 # m/s lateral
|
|
||||||
theta_cmd = 0.0 # deg/s rotation
|
|
||||||
if self.pressed_keys["forward"]:
|
|
||||||
x_cmd += xy_speed
|
|
||||||
if self.pressed_keys["backward"]:
|
|
||||||
x_cmd -= xy_speed
|
|
||||||
if self.pressed_keys["left"]:
|
|
||||||
y_cmd += xy_speed
|
|
||||||
if self.pressed_keys["right"]:
|
|
||||||
y_cmd -= xy_speed
|
|
||||||
if self.pressed_keys["rotate_left"]:
|
|
||||||
theta_cmd += theta_speed
|
|
||||||
if self.pressed_keys["rotate_right"]:
|
|
||||||
theta_cmd -= theta_speed
|
|
||||||
|
|
||||||
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
|
||||||
|
|
||||||
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions}
|
|
||||||
self.cmd_socket.send_string(json.dumps(message))
|
|
||||||
|
|
||||||
if not record_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
obs_dict = self.capture_observation()
|
|
||||||
|
|
||||||
arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32)
|
|
||||||
|
|
||||||
wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands)
|
|
||||||
wheel_velocity_mm = (
|
|
||||||
wheel_velocity_tuple[0] * 1000.0,
|
|
||||||
wheel_velocity_tuple[1] * 1000.0,
|
|
||||||
wheel_velocity_tuple[2],
|
|
||||||
)
|
|
||||||
wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32)
|
|
||||||
action_tensor = torch.cat([arm_state_tensor, wheel_tensor])
|
|
||||||
action_dict = {"action": action_tensor}
|
|
||||||
|
|
||||||
return obs_dict, action_dict
|
|
||||||
|
|
||||||
def capture_observation(self) -> dict:
|
|
||||||
"""
|
|
||||||
Capture observations from the remote robot: current follower arm positions,
|
|
||||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
|
||||||
and a camera frame.
|
|
||||||
"""
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
|
|
||||||
|
|
||||||
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
|
||||||
|
|
||||||
body_state = self.wheel_raw_to_body(present_speed)
|
|
||||||
|
|
||||||
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
|
||||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
|
||||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
|
||||||
|
|
||||||
obs_dict = {"observation.state": combined_state_tensor}
|
|
||||||
|
|
||||||
# Loop over each configured camera
|
|
||||||
for cam_name, cam in self.cameras.items():
|
|
||||||
frame = frames.get(cam_name, None)
|
|
||||||
if frame is None:
|
|
||||||
# Create a black image using the camera's configured width, height, and channels
|
|
||||||
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
|
|
||||||
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
|
|
||||||
|
|
||||||
return obs_dict
|
|
||||||
|
|
||||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
|
|
||||||
|
|
||||||
# Ensure the action tensor has at least 9 elements:
|
|
||||||
# - First 6: arm positions.
|
|
||||||
# - Last 3: base commands.
|
|
||||||
if action.numel() < 9:
|
|
||||||
# Pad with zeros if there are not enough elements.
|
|
||||||
padded = torch.zeros(9, dtype=action.dtype)
|
|
||||||
padded[: action.numel()] = action
|
|
||||||
action = padded
|
|
||||||
|
|
||||||
# Extract arm and base actions.
|
|
||||||
arm_actions = action[:6].flatten()
|
|
||||||
base_actions = action[6:].flatten()
|
|
||||||
|
|
||||||
x_cmd_mm = base_actions[0].item() # mm/s
|
|
||||||
y_cmd_mm = base_actions[1].item() # mm/s
|
|
||||||
theta_cmd = base_actions[2].item() # deg/s
|
|
||||||
|
|
||||||
# Convert mm/s to m/s for the kinematics calculations.
|
|
||||||
x_cmd = x_cmd_mm / 1000.0 # m/s
|
|
||||||
y_cmd = y_cmd_mm / 1000.0 # m/s
|
|
||||||
|
|
||||||
# Compute wheel commands from body commands.
|
|
||||||
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
|
||||||
|
|
||||||
arm_positions_list = arm_actions.tolist()
|
|
||||||
|
|
||||||
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list}
|
|
||||||
self.cmd_socket.send_string(json.dumps(message))
|
|
||||||
|
|
||||||
return action
|
|
||||||
|
|
||||||
def print_logs(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def disconnect(self):
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError("Not connected.")
|
|
||||||
if self.cmd_socket:
|
|
||||||
stop_cmd = {
|
|
||||||
"raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
|
|
||||||
"arm_positions": {},
|
|
||||||
}
|
|
||||||
self.cmd_socket.send_string(json.dumps(stop_cmd))
|
|
||||||
self.cmd_socket.close()
|
|
||||||
if self.video_socket:
|
|
||||||
self.video_socket.close()
|
|
||||||
if self.context:
|
|
||||||
self.context.term()
|
|
||||||
if PYNPUT_AVAILABLE:
|
|
||||||
self.listener.stop()
|
|
||||||
self.is_connected = False
|
|
||||||
print("[INFO] Disconnected from remote robot.")
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if getattr(self, "is_connected", False):
|
|
||||||
self.disconnect()
|
|
||||||
if PYNPUT_AVAILABLE:
|
|
||||||
self.listener.stop()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def degps_to_raw(degps: float) -> int:
|
|
||||||
steps_per_deg = 4096.0 / 360.0
|
|
||||||
speed_in_steps = abs(degps) * steps_per_deg
|
|
||||||
speed_int = int(round(speed_in_steps))
|
|
||||||
if speed_int > 0x7FFF:
|
|
||||||
speed_int = 0x7FFF
|
|
||||||
if degps < 0:
|
|
||||||
return speed_int | 0x8000
|
|
||||||
else:
|
|
||||||
return speed_int & 0x7FFF
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def raw_to_degps(raw_speed: int) -> float:
|
|
||||||
steps_per_deg = 4096.0 / 360.0
|
|
||||||
magnitude = raw_speed & 0x7FFF
|
|
||||||
degps = magnitude / steps_per_deg
|
|
||||||
if raw_speed & 0x8000:
|
|
||||||
degps = -degps
|
|
||||||
return degps
|
|
||||||
|
|
||||||
def body_to_wheel_raw(
|
|
||||||
self,
|
|
||||||
x_cmd: float,
|
|
||||||
y_cmd: float,
|
|
||||||
theta_cmd: float,
|
|
||||||
wheel_radius: float = 0.05,
|
|
||||||
base_radius: float = 0.125,
|
|
||||||
max_raw: int = 3000,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Convert desired body-frame velocities into wheel raw commands.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
x_cmd : Linear velocity in x (m/s).
|
|
||||||
y_cmd : Linear velocity in y (m/s).
|
|
||||||
theta_cmd : Rotational velocity (deg/s).
|
|
||||||
wheel_radius: Radius of each wheel (meters).
|
|
||||||
base_radius : Distance from the center of rotation to each wheel (meters).
|
|
||||||
max_raw : Maximum allowed raw command (ticks) per wheel.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary with wheel raw commands:
|
|
||||||
{"left_wheel": value, "back_wheel": value, "right_wheel": value}.
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Internally, the method converts theta_cmd to rad/s for the kinematics.
|
|
||||||
- The raw command is computed from the wheels angular speed in deg/s
|
|
||||||
using degps_to_raw(). If any command exceeds max_raw, all commands
|
|
||||||
are scaled down proportionally.
|
|
||||||
"""
|
|
||||||
# Convert rotational velocity from deg/s to rad/s.
|
|
||||||
theta_rad = theta_cmd * (np.pi / 180.0)
|
|
||||||
# Create the body velocity vector [x, y, theta_rad].
|
|
||||||
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
|
|
||||||
|
|
||||||
# Define the wheel mounting angles with a -90° offset.
|
|
||||||
angles = np.radians(np.array([240, 120, 0]) - 90)
|
|
||||||
# Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
|
|
||||||
# The third column (base_radius) accounts for the effect of rotation.
|
|
||||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
|
||||||
|
|
||||||
# Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s).
|
|
||||||
wheel_linear_speeds = m.dot(velocity_vector)
|
|
||||||
wheel_angular_speeds = wheel_linear_speeds / wheel_radius
|
|
||||||
|
|
||||||
# Convert wheel angular speeds from rad/s to deg/s.
|
|
||||||
wheel_degps = wheel_angular_speeds * (180.0 / np.pi)
|
|
||||||
|
|
||||||
# Scaling
|
|
||||||
steps_per_deg = 4096.0 / 360.0
|
|
||||||
raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
|
|
||||||
max_raw_computed = max(raw_floats)
|
|
||||||
if max_raw_computed > max_raw:
|
|
||||||
scale = max_raw / max_raw_computed
|
|
||||||
wheel_degps = wheel_degps * scale
|
|
||||||
|
|
||||||
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
|
||||||
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
|
|
||||||
|
|
||||||
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
|
||||||
|
|
||||||
def wheel_raw_to_body(
|
|
||||||
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
|
|
||||||
) -> tuple:
|
|
||||||
"""
|
|
||||||
Convert wheel raw command feedback back into body-frame velocities.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel").
|
|
||||||
wheel_radius: Radius of each wheel (meters).
|
|
||||||
base_radius : Distance from the robot center to each wheel (meters).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple (x_cmd, y_cmd, theta_cmd) where:
|
|
||||||
x_cmd : Linear velocity in x (m/s).
|
|
||||||
y_cmd : Linear velocity in y (m/s).
|
|
||||||
theta_cmd : Rotational velocity in deg/s.
|
|
||||||
"""
|
|
||||||
# Extract the raw values in order.
|
|
||||||
raw_list = [
|
|
||||||
int(wheel_raw.get("left_wheel", 0)),
|
|
||||||
int(wheel_raw.get("back_wheel", 0)),
|
|
||||||
int(wheel_raw.get("right_wheel", 0)),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Convert each raw command back to an angular speed in deg/s.
|
|
||||||
wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list])
|
|
||||||
# Convert from deg/s to rad/s.
|
|
||||||
wheel_radps = wheel_degps * (np.pi / 180.0)
|
|
||||||
# Compute each wheel’s linear speed (m/s) from its angular speed.
|
|
||||||
wheel_linear_speeds = wheel_radps * wheel_radius
|
|
||||||
|
|
||||||
# Define the wheel mounting angles with a -90° offset.
|
|
||||||
angles = np.radians(np.array([240, 120, 0]) - 90)
|
|
||||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
|
||||||
|
|
||||||
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
|
|
||||||
m_inv = np.linalg.inv(m)
|
|
||||||
velocity_vector = m_inv.dot(wheel_linear_speeds)
|
|
||||||
x_cmd, y_cmd, theta_rad = velocity_vector
|
|
||||||
theta_cmd = theta_rad * (180.0 / np.pi)
|
|
||||||
return (x_cmd, y_cmd, theta_cmd)
|
|
||||||
|
|
||||||
|
|
||||||
class LeKiwi:
|
|
||||||
def __init__(self, motor_bus):
|
|
||||||
"""
|
|
||||||
Initializes the LeKiwi with Feetech motors bus.
|
|
||||||
"""
|
|
||||||
self.motor_bus = motor_bus
|
|
||||||
self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"]
|
|
||||||
|
|
||||||
# Initialize motors in velocity mode.
|
|
||||||
self.motor_bus.write("Lock", 0)
|
|
||||||
self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids)
|
|
||||||
self.motor_bus.write("Lock", 1)
|
|
||||||
print("Motors set to velocity mode.")
|
|
||||||
|
|
||||||
def read_velocity(self):
|
|
||||||
"""
|
|
||||||
Reads the raw speeds for all wheels. Returns a dictionary with motor names:
|
|
||||||
"""
|
|
||||||
raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids)
|
|
||||||
return {
|
|
||||||
"left_wheel": int(raw_speeds[0]),
|
|
||||||
"back_wheel": int(raw_speeds[1]),
|
|
||||||
"right_wheel": int(raw_speeds[2]),
|
|
||||||
}
|
|
||||||
|
|
||||||
def set_velocity(self, command_speeds):
|
|
||||||
"""
|
|
||||||
Sends raw velocity commands (16-bit encoded values) directly to the motor bus.
|
|
||||||
The order of speeds must correspond to self.motor_ids.
|
|
||||||
"""
|
|
||||||
self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids)
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
"""Stops the robot by setting all motor speeds to zero."""
|
|
||||||
self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids)
|
|
||||||
print("Motors stopped.")
|
|
|
@ -1,704 +0,0 @@
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import zmq
|
|
||||||
|
|
||||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
|
||||||
from lerobot.common.errors import DeviceNotConnectedError
|
|
||||||
from lerobot.common.motors.feetech.feetech import TorqueMode
|
|
||||||
from lerobot.common.motors.feetech.feetech_calibration import run_full_arm_calibration
|
|
||||||
from lerobot.common.motors.motors_bus import MotorsBus
|
|
||||||
from lerobot.common.motors.utils import make_motors_buses_from_configs
|
|
||||||
from lerobot.common.robots.lekiwi.configuration_lekiwi import LeKiwiRobotConfig
|
|
||||||
from lerobot.common.robots.utils import get_arm_id
|
|
||||||
|
|
||||||
PYNPUT_AVAILABLE = True
|
|
||||||
try:
|
|
||||||
# Only import if there's a valid X server or if we're not on a Pi
|
|
||||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
|
||||||
print("No DISPLAY set. Skipping pynput import.")
|
|
||||||
raise ImportError("pynput blocked intentionally due to no display.")
|
|
||||||
|
|
||||||
from pynput import keyboard
|
|
||||||
except ImportError:
|
|
||||||
keyboard = None
|
|
||||||
PYNPUT_AVAILABLE = False
|
|
||||||
except Exception as e:
|
|
||||||
keyboard = None
|
|
||||||
PYNPUT_AVAILABLE = False
|
|
||||||
print(f"Could not import pynput: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class MobileManipulator:
|
|
||||||
"""
|
|
||||||
MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot.
|
|
||||||
The robot includes a three omniwheel mobile base and a remote follower arm.
|
|
||||||
The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
|
|
||||||
forwarded to the remote follower arm (after applying a safety clamp).
|
|
||||||
In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: LeKiwiRobotConfig):
|
|
||||||
"""
|
|
||||||
Expected keys in config:
|
|
||||||
- ip, port, video_port for the remote connection.
|
|
||||||
- calibration_dir, leader_arms, follower_arms, max_relative_target, etc.
|
|
||||||
"""
|
|
||||||
self.robot_type = config.type
|
|
||||||
self.config = config
|
|
||||||
self.remote_ip = config.ip
|
|
||||||
self.remote_port = config.port
|
|
||||||
self.remote_port_video = config.video_port
|
|
||||||
self.calibration_dir = Path(self.config.calibration_dir)
|
|
||||||
self.logs = {}
|
|
||||||
|
|
||||||
self.teleop_keys = self.config.teleop_keys
|
|
||||||
|
|
||||||
# For teleoperation, the leader arm (local) is used to record the desired arm pose.
|
|
||||||
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
|
|
||||||
|
|
||||||
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
|
|
||||||
|
|
||||||
self.cameras = make_cameras_from_configs(self.config.cameras)
|
|
||||||
|
|
||||||
self.is_connected = False
|
|
||||||
|
|
||||||
self.last_frames = {}
|
|
||||||
self.last_present_speed = {}
|
|
||||||
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Define three speed levels and a current index
|
|
||||||
self.speed_levels = [
|
|
||||||
{"xy": 0.1, "theta": 30}, # slow
|
|
||||||
{"xy": 0.2, "theta": 60}, # medium
|
|
||||||
{"xy": 0.3, "theta": 90}, # fast
|
|
||||||
]
|
|
||||||
self.speed_index = 0 # Start at slow
|
|
||||||
|
|
||||||
# ZeroMQ context and sockets.
|
|
||||||
self.context = None
|
|
||||||
self.cmd_socket = None
|
|
||||||
self.video_socket = None
|
|
||||||
|
|
||||||
# Keyboard state for base teleoperation.
|
|
||||||
self.running = True
|
|
||||||
self.pressed_keys = {
|
|
||||||
"forward": False,
|
|
||||||
"backward": False,
|
|
||||||
"left": False,
|
|
||||||
"right": False,
|
|
||||||
"rotate_left": False,
|
|
||||||
"rotate_right": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
if PYNPUT_AVAILABLE:
|
|
||||||
print("pynput is available - enabling local keyboard listener.")
|
|
||||||
self.listener = keyboard.Listener(
|
|
||||||
on_press=self.on_press,
|
|
||||||
on_release=self.on_release,
|
|
||||||
)
|
|
||||||
self.listener.start()
|
|
||||||
else:
|
|
||||||
print("pynput not available - skipping local keyboard listener.")
|
|
||||||
self.listener = None
|
|
||||||
|
|
||||||
def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
|
|
||||||
return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def camera_features(self) -> dict:
|
|
||||||
cam_ft = {}
|
|
||||||
for cam_key, cam in self.cameras.items():
|
|
||||||
key = f"observation.images.{cam_key}"
|
|
||||||
cam_ft[key] = {
|
|
||||||
"shape": (cam.height, cam.width, cam.channels),
|
|
||||||
"names": ["height", "width", "channels"],
|
|
||||||
"info": None,
|
|
||||||
}
|
|
||||||
return cam_ft
|
|
||||||
|
|
||||||
@property
|
|
||||||
def motor_features(self) -> dict:
|
|
||||||
follower_arm_names = [
|
|
||||||
"shoulder_pan",
|
|
||||||
"shoulder_lift",
|
|
||||||
"elbow_flex",
|
|
||||||
"wrist_flex",
|
|
||||||
"wrist_roll",
|
|
||||||
"gripper",
|
|
||||||
]
|
|
||||||
observations = ["x_mm", "y_mm", "theta"]
|
|
||||||
combined_names = follower_arm_names + observations
|
|
||||||
return {
|
|
||||||
"action": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (len(combined_names),),
|
|
||||||
"names": combined_names,
|
|
||||||
},
|
|
||||||
"observation.state": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (len(combined_names),),
|
|
||||||
"names": combined_names,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def features(self):
|
|
||||||
return {**self.motor_features, **self.camera_features}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def has_camera(self):
|
|
||||||
return len(self.cameras) > 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_cameras(self):
|
|
||||||
return len(self.cameras)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available_arms(self):
|
|
||||||
available = []
|
|
||||||
for name in self.leader_arms:
|
|
||||||
available.append(get_arm_id(name, "leader"))
|
|
||||||
for name in self.follower_arms:
|
|
||||||
available.append(get_arm_id(name, "follower"))
|
|
||||||
return available
|
|
||||||
|
|
||||||
def on_press(self, key):
|
|
||||||
try:
|
|
||||||
# Movement
|
|
||||||
if key.char == self.teleop_keys["forward"]:
|
|
||||||
self.pressed_keys["forward"] = True
|
|
||||||
elif key.char == self.teleop_keys["backward"]:
|
|
||||||
self.pressed_keys["backward"] = True
|
|
||||||
elif key.char == self.teleop_keys["left"]:
|
|
||||||
self.pressed_keys["left"] = True
|
|
||||||
elif key.char == self.teleop_keys["right"]:
|
|
||||||
self.pressed_keys["right"] = True
|
|
||||||
elif key.char == self.teleop_keys["rotate_left"]:
|
|
||||||
self.pressed_keys["rotate_left"] = True
|
|
||||||
elif key.char == self.teleop_keys["rotate_right"]:
|
|
||||||
self.pressed_keys["rotate_right"] = True
|
|
||||||
|
|
||||||
# Quit teleoperation
|
|
||||||
elif key.char == self.teleop_keys["quit"]:
|
|
||||||
self.running = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Speed control
|
|
||||||
elif key.char == self.teleop_keys["speed_up"]:
|
|
||||||
self.speed_index = min(self.speed_index + 1, 2)
|
|
||||||
print(f"Speed index increased to {self.speed_index}")
|
|
||||||
elif key.char == self.teleop_keys["speed_down"]:
|
|
||||||
self.speed_index = max(self.speed_index - 1, 0)
|
|
||||||
print(f"Speed index decreased to {self.speed_index}")
|
|
||||||
|
|
||||||
except AttributeError:
|
|
||||||
# e.g., if key is special like Key.esc
|
|
||||||
if key == keyboard.Key.esc:
|
|
||||||
self.running = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
def on_release(self, key):
|
|
||||||
try:
|
|
||||||
if hasattr(key, "char"):
|
|
||||||
if key.char == self.teleop_keys["forward"]:
|
|
||||||
self.pressed_keys["forward"] = False
|
|
||||||
elif key.char == self.teleop_keys["backward"]:
|
|
||||||
self.pressed_keys["backward"] = False
|
|
||||||
elif key.char == self.teleop_keys["left"]:
|
|
||||||
self.pressed_keys["left"] = False
|
|
||||||
elif key.char == self.teleop_keys["right"]:
|
|
||||||
self.pressed_keys["right"] = False
|
|
||||||
elif key.char == self.teleop_keys["rotate_left"]:
|
|
||||||
self.pressed_keys["rotate_left"] = False
|
|
||||||
elif key.char == self.teleop_keys["rotate_right"]:
|
|
||||||
self.pressed_keys["rotate_right"] = False
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
if not self.leader_arms:
|
|
||||||
raise ValueError("MobileManipulator has no leader arm to connect.")
|
|
||||||
for name in self.leader_arms:
|
|
||||||
print(f"Connecting {name} leader arm.")
|
|
||||||
self.calibrate_leader()
|
|
||||||
|
|
||||||
# Set up ZeroMQ sockets to communicate with the remote mobile robot.
|
|
||||||
self.context = zmq.Context()
|
|
||||||
self.cmd_socket = self.context.socket(zmq.PUSH)
|
|
||||||
connection_string = f"tcp://{self.remote_ip}:{self.remote_port}"
|
|
||||||
self.cmd_socket.connect(connection_string)
|
|
||||||
self.cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
|
||||||
self.video_socket = self.context.socket(zmq.PULL)
|
|
||||||
video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}"
|
|
||||||
self.video_socket.connect(video_connection)
|
|
||||||
self.video_socket.setsockopt(zmq.CONFLATE, 1)
|
|
||||||
print(
|
|
||||||
f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}."
|
|
||||||
)
|
|
||||||
self.is_connected = True
|
|
||||||
|
|
||||||
def load_or_run_calibration_(self, name, arm, arm_type):
|
|
||||||
arm_id = get_arm_id(name, arm_type)
|
|
||||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
|
||||||
|
|
||||||
if arm_calib_path.exists():
|
|
||||||
with open(arm_calib_path) as f:
|
|
||||||
calibration = json.load(f)
|
|
||||||
else:
|
|
||||||
print(f"Missing calibration file '{arm_calib_path}'")
|
|
||||||
calibration = run_full_arm_calibration(arm, self.robot_type, name, arm_type)
|
|
||||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
|
||||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(arm_calib_path, "w") as f:
|
|
||||||
json.dump(calibration, f)
|
|
||||||
|
|
||||||
return calibration
|
|
||||||
|
|
||||||
def calibrate_leader(self):
|
|
||||||
for name, arm in self.leader_arms.items():
|
|
||||||
# Connect the bus
|
|
||||||
arm.connect()
|
|
||||||
|
|
||||||
# Disable torque on all motors
|
|
||||||
for motor_id in arm.motors:
|
|
||||||
arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id)
|
|
||||||
|
|
||||||
# Now run calibration
|
|
||||||
calibration = self.load_or_run_calibration_(name, arm, "leader")
|
|
||||||
arm.set_calibration(calibration)
|
|
||||||
|
|
||||||
def calibrate_follower(self):
|
|
||||||
for name, bus in self.follower_arms.items():
|
|
||||||
bus.connect()
|
|
||||||
|
|
||||||
# Disable torque on all motors
|
|
||||||
for motor_id in bus.motors:
|
|
||||||
bus.write("Torque_Enable", 0, motor_id)
|
|
||||||
|
|
||||||
# Then filter out wheels
|
|
||||||
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
|
|
||||||
if not arm_only_dict:
|
|
||||||
continue
|
|
||||||
|
|
||||||
original_motors = bus.motors
|
|
||||||
bus.motors = arm_only_dict
|
|
||||||
|
|
||||||
calibration = self.load_or_run_calibration_(name, bus, "follower")
|
|
||||||
bus.set_calibration(calibration)
|
|
||||||
|
|
||||||
bus.motors = original_motors
|
|
||||||
|
|
||||||
def _get_data(self):
|
|
||||||
"""
|
|
||||||
Polls the video socket for up to 15 ms. If data arrives, decode only
|
|
||||||
the *latest* message, returning frames, speed, and arm state. If
|
|
||||||
nothing arrives for any field, use the last known values.
|
|
||||||
"""
|
|
||||||
frames = {}
|
|
||||||
present_speed = {}
|
|
||||||
remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Poll up to 15 ms
|
|
||||||
poller = zmq.Poller()
|
|
||||||
poller.register(self.video_socket, zmq.POLLIN)
|
|
||||||
socks = dict(poller.poll(15))
|
|
||||||
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
|
|
||||||
# No new data arrived → reuse ALL old data
|
|
||||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
|
||||||
|
|
||||||
# Drain all messages, keep only the last
|
|
||||||
last_msg = None
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
obs_string = self.video_socket.recv_string(zmq.NOBLOCK)
|
|
||||||
last_msg = obs_string
|
|
||||||
except zmq.Again:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not last_msg:
|
|
||||||
# No new message → also reuse old
|
|
||||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
|
||||||
|
|
||||||
# Decode only the final message
|
|
||||||
try:
|
|
||||||
observation = json.loads(last_msg)
|
|
||||||
|
|
||||||
images_dict = observation.get("images", {})
|
|
||||||
new_speed = observation.get("present_speed", {})
|
|
||||||
new_arm_state = observation.get("follower_arm_state", None)
|
|
||||||
|
|
||||||
# Convert images
|
|
||||||
for cam_name, image_b64 in images_dict.items():
|
|
||||||
if image_b64:
|
|
||||||
jpg_data = base64.b64decode(image_b64)
|
|
||||||
np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
|
|
||||||
frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
|
||||||
if frame_candidate is not None:
|
|
||||||
frames[cam_name] = frame_candidate
|
|
||||||
|
|
||||||
# If remote_arm_state is None and frames is None there is no message then use the previous message
|
|
||||||
if new_arm_state is not None and frames is not None:
|
|
||||||
self.last_frames = frames
|
|
||||||
|
|
||||||
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
|
|
||||||
self.last_remote_arm_state = remote_arm_state_tensor
|
|
||||||
|
|
||||||
present_speed = new_speed
|
|
||||||
self.last_present_speed = new_speed
|
|
||||||
else:
|
|
||||||
frames = self.last_frames
|
|
||||||
|
|
||||||
remote_arm_state_tensor = self.last_remote_arm_state
|
|
||||||
|
|
||||||
present_speed = self.last_present_speed
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[DEBUG] Error decoding video message: {e}")
|
|
||||||
# If decode fails, fall back to old data
|
|
||||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
|
||||||
|
|
||||||
return frames, present_speed, remote_arm_state_tensor
|
|
||||||
|
|
||||||
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
|
|
||||||
state_tensor = torch.zeros(3, dtype=torch.int32)
|
|
||||||
if present_speed:
|
|
||||||
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
|
|
||||||
if "1" in decoded:
|
|
||||||
state_tensor[0] = decoded["1"]
|
|
||||||
if "2" in decoded:
|
|
||||||
state_tensor[1] = decoded["2"]
|
|
||||||
if "3" in decoded:
|
|
||||||
state_tensor[2] = decoded["3"]
|
|
||||||
return state_tensor
|
|
||||||
|
|
||||||
def teleop_step(
|
|
||||||
self, record_data: bool = False
|
|
||||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
|
|
||||||
|
|
||||||
speed_setting = self.speed_levels[self.speed_index]
|
|
||||||
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
|
|
||||||
theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90
|
|
||||||
|
|
||||||
# Prepare to assign the position of the leader to the follower
|
|
||||||
arm_positions = []
|
|
||||||
for name in self.leader_arms:
|
|
||||||
pos = self.leader_arms[name].read("Present_Position")
|
|
||||||
pos_tensor = torch.from_numpy(pos).float()
|
|
||||||
arm_positions.extend(pos_tensor.tolist())
|
|
||||||
|
|
||||||
y_cmd = 0.0 # m/s forward/backward
|
|
||||||
x_cmd = 0.0 # m/s lateral
|
|
||||||
theta_cmd = 0.0 # deg/s rotation
|
|
||||||
if self.pressed_keys["forward"]:
|
|
||||||
y_cmd += xy_speed
|
|
||||||
if self.pressed_keys["backward"]:
|
|
||||||
y_cmd -= xy_speed
|
|
||||||
if self.pressed_keys["left"]:
|
|
||||||
x_cmd += xy_speed
|
|
||||||
if self.pressed_keys["right"]:
|
|
||||||
x_cmd -= xy_speed
|
|
||||||
if self.pressed_keys["rotate_left"]:
|
|
||||||
theta_cmd += theta_speed
|
|
||||||
if self.pressed_keys["rotate_right"]:
|
|
||||||
theta_cmd -= theta_speed
|
|
||||||
|
|
||||||
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
|
||||||
|
|
||||||
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions}
|
|
||||||
self.cmd_socket.send_string(json.dumps(message))
|
|
||||||
|
|
||||||
if not record_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
obs_dict = self.capture_observation()
|
|
||||||
|
|
||||||
arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32)
|
|
||||||
|
|
||||||
wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands)
|
|
||||||
wheel_velocity_mm = (
|
|
||||||
wheel_velocity_tuple[0] * 1000.0,
|
|
||||||
wheel_velocity_tuple[1] * 1000.0,
|
|
||||||
wheel_velocity_tuple[2],
|
|
||||||
)
|
|
||||||
wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32)
|
|
||||||
action_tensor = torch.cat([arm_state_tensor, wheel_tensor])
|
|
||||||
action_dict = {"action": action_tensor}
|
|
||||||
|
|
||||||
return obs_dict, action_dict
|
|
||||||
|
|
||||||
def capture_observation(self) -> dict:
|
|
||||||
"""
|
|
||||||
Capture observations from the remote robot: current follower arm positions,
|
|
||||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
|
||||||
and a camera frame.
|
|
||||||
"""
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
|
|
||||||
|
|
||||||
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
|
||||||
|
|
||||||
body_state = self.wheel_raw_to_body(present_speed)
|
|
||||||
|
|
||||||
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
|
||||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
|
||||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
|
||||||
|
|
||||||
obs_dict = {"observation.state": combined_state_tensor}
|
|
||||||
|
|
||||||
# Loop over each configured camera
|
|
||||||
for cam_name, cam in self.cameras.items():
|
|
||||||
frame = frames.get(cam_name, None)
|
|
||||||
if frame is None:
|
|
||||||
# Create a black image using the camera's configured width, height, and channels
|
|
||||||
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
|
|
||||||
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
|
|
||||||
|
|
||||||
return obs_dict
|
|
||||||
|
|
||||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
|
|
||||||
|
|
||||||
# Ensure the action tensor has at least 9 elements:
|
|
||||||
# - First 6: arm positions.
|
|
||||||
# - Last 3: base commands.
|
|
||||||
if action.numel() < 9:
|
|
||||||
# Pad with zeros if there are not enough elements.
|
|
||||||
padded = torch.zeros(9, dtype=action.dtype)
|
|
||||||
padded[: action.numel()] = action
|
|
||||||
action = padded
|
|
||||||
|
|
||||||
# Extract arm and base actions.
|
|
||||||
arm_actions = action[:6].flatten()
|
|
||||||
base_actions = action[6:].flatten()
|
|
||||||
|
|
||||||
x_cmd_mm = base_actions[0].item() # mm/s
|
|
||||||
y_cmd_mm = base_actions[1].item() # mm/s
|
|
||||||
theta_cmd = base_actions[2].item() # deg/s
|
|
||||||
|
|
||||||
# Convert mm/s to m/s for the kinematics calculations.
|
|
||||||
x_cmd = x_cmd_mm / 1000.0 # m/s
|
|
||||||
y_cmd = y_cmd_mm / 1000.0 # m/s
|
|
||||||
|
|
||||||
# Compute wheel commands from body commands.
|
|
||||||
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
|
||||||
|
|
||||||
arm_positions_list = arm_actions.tolist()
|
|
||||||
|
|
||||||
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list}
|
|
||||||
self.cmd_socket.send_string(json.dumps(message))
|
|
||||||
|
|
||||||
return action
|
|
||||||
|
|
||||||
def print_logs(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def disconnect(self):
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError("Not connected.")
|
|
||||||
if self.cmd_socket:
|
|
||||||
stop_cmd = {
|
|
||||||
"raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
|
|
||||||
"arm_positions": {},
|
|
||||||
}
|
|
||||||
self.cmd_socket.send_string(json.dumps(stop_cmd))
|
|
||||||
self.cmd_socket.close()
|
|
||||||
if self.video_socket:
|
|
||||||
self.video_socket.close()
|
|
||||||
if self.context:
|
|
||||||
self.context.term()
|
|
||||||
if PYNPUT_AVAILABLE:
|
|
||||||
self.listener.stop()
|
|
||||||
self.is_connected = False
|
|
||||||
print("[INFO] Disconnected from remote robot.")
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if getattr(self, "is_connected", False):
|
|
||||||
self.disconnect()
|
|
||||||
if PYNPUT_AVAILABLE:
|
|
||||||
self.listener.stop()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def degps_to_raw(degps: float) -> int:
|
|
||||||
steps_per_deg = 4096.0 / 360.0
|
|
||||||
speed_in_steps = abs(degps) * steps_per_deg
|
|
||||||
speed_int = int(round(speed_in_steps))
|
|
||||||
if speed_int > 0x7FFF:
|
|
||||||
speed_int = 0x7FFF
|
|
||||||
if degps < 0:
|
|
||||||
return speed_int | 0x8000
|
|
||||||
else:
|
|
||||||
return speed_int & 0x7FFF
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def raw_to_degps(raw_speed: int) -> float:
|
|
||||||
steps_per_deg = 4096.0 / 360.0
|
|
||||||
magnitude = raw_speed & 0x7FFF
|
|
||||||
degps = magnitude / steps_per_deg
|
|
||||||
if raw_speed & 0x8000:
|
|
||||||
degps = -degps
|
|
||||||
return degps
|
|
||||||
|
|
||||||
def body_to_wheel_raw(
|
|
||||||
self,
|
|
||||||
x_cmd: float,
|
|
||||||
y_cmd: float,
|
|
||||||
theta_cmd: float,
|
|
||||||
wheel_radius: float = 0.05,
|
|
||||||
base_radius: float = 0.125,
|
|
||||||
max_raw: int = 3000,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Convert desired body-frame velocities into wheel raw commands.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
x_cmd : Linear velocity in x (m/s).
|
|
||||||
y_cmd : Linear velocity in y (m/s).
|
|
||||||
theta_cmd : Rotational velocity (deg/s).
|
|
||||||
wheel_radius: Radius of each wheel (meters).
|
|
||||||
base_radius : Distance from the center of rotation to each wheel (meters).
|
|
||||||
max_raw : Maximum allowed raw command (ticks) per wheel.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary with wheel raw commands:
|
|
||||||
{"left_wheel": value, "back_wheel": value, "right_wheel": value}.
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Internally, the method converts theta_cmd to rad/s for the kinematics.
|
|
||||||
- The raw command is computed from the wheels angular speed in deg/s
|
|
||||||
using degps_to_raw(). If any command exceeds max_raw, all commands
|
|
||||||
are scaled down proportionally.
|
|
||||||
"""
|
|
||||||
# Convert rotational velocity from deg/s to rad/s.
|
|
||||||
theta_rad = theta_cmd * (np.pi / 180.0)
|
|
||||||
# Create the body velocity vector [x, y, theta_rad].
|
|
||||||
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
|
|
||||||
|
|
||||||
# Define the wheel mounting angles (defined from y axis cw)
|
|
||||||
angles = np.radians(np.array([300, 180, 60]))
|
|
||||||
# Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
|
|
||||||
# The third column (base_radius) accounts for the effect of rotation.
|
|
||||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
|
||||||
|
|
||||||
# Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s).
|
|
||||||
wheel_linear_speeds = m.dot(velocity_vector)
|
|
||||||
wheel_angular_speeds = wheel_linear_speeds / wheel_radius
|
|
||||||
|
|
||||||
# Convert wheel angular speeds from rad/s to deg/s.
|
|
||||||
wheel_degps = wheel_angular_speeds * (180.0 / np.pi)
|
|
||||||
|
|
||||||
# Scaling
|
|
||||||
steps_per_deg = 4096.0 / 360.0
|
|
||||||
raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
|
|
||||||
max_raw_computed = max(raw_floats)
|
|
||||||
if max_raw_computed > max_raw:
|
|
||||||
scale = max_raw / max_raw_computed
|
|
||||||
wheel_degps = wheel_degps * scale
|
|
||||||
|
|
||||||
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
|
||||||
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
|
|
||||||
|
|
||||||
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
|
||||||
|
|
||||||
def wheel_raw_to_body(
|
|
||||||
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
|
|
||||||
) -> tuple:
|
|
||||||
"""
|
|
||||||
Convert wheel raw command feedback back into body-frame velocities.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel").
|
|
||||||
wheel_radius: Radius of each wheel (meters).
|
|
||||||
base_radius : Distance from the robot center to each wheel (meters).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple (x_cmd, y_cmd, theta_cmd) where:
|
|
||||||
x_cmd : Linear velocity in x (m/s).
|
|
||||||
y_cmd : Linear velocity in y (m/s).
|
|
||||||
theta_cmd : Rotational velocity in deg/s.
|
|
||||||
"""
|
|
||||||
# Extract the raw values in order.
|
|
||||||
raw_list = [
|
|
||||||
int(wheel_raw.get("left_wheel", 0)),
|
|
||||||
int(wheel_raw.get("back_wheel", 0)),
|
|
||||||
int(wheel_raw.get("right_wheel", 0)),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Convert each raw command back to an angular speed in deg/s.
|
|
||||||
wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list])
|
|
||||||
# Convert from deg/s to rad/s.
|
|
||||||
wheel_radps = wheel_degps * (np.pi / 180.0)
|
|
||||||
# Compute each wheel’s linear speed (m/s) from its angular speed.
|
|
||||||
wheel_linear_speeds = wheel_radps * wheel_radius
|
|
||||||
|
|
||||||
# Define the wheel mounting angles (defined from y axis cw)
|
|
||||||
angles = np.radians(np.array([300, 180, 60]))
|
|
||||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
|
||||||
|
|
||||||
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
|
|
||||||
m_inv = np.linalg.inv(m)
|
|
||||||
velocity_vector = m_inv.dot(wheel_linear_speeds)
|
|
||||||
x_cmd, y_cmd, theta_rad = velocity_vector
|
|
||||||
theta_cmd = theta_rad * (180.0 / np.pi)
|
|
||||||
return (x_cmd, y_cmd, theta_cmd)
|
|
||||||
|
|
||||||
|
|
||||||
class LeKiwi:
|
|
||||||
def __init__(self, motor_bus):
|
|
||||||
"""
|
|
||||||
Initializes the LeKiwi with Feetech motors bus.
|
|
||||||
"""
|
|
||||||
self.motor_bus = motor_bus
|
|
||||||
self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"]
|
|
||||||
|
|
||||||
# Initialize motors in velocity mode.
|
|
||||||
self.motor_bus.write("Lock", 0)
|
|
||||||
self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids)
|
|
||||||
self.motor_bus.write("Lock", 1)
|
|
||||||
print("Motors set to velocity mode.")
|
|
||||||
|
|
||||||
def read_velocity(self):
|
|
||||||
"""
|
|
||||||
Reads the raw speeds for all wheels. Returns a dictionary with motor names:
|
|
||||||
"""
|
|
||||||
raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids)
|
|
||||||
return {
|
|
||||||
"left_wheel": int(raw_speeds[0]),
|
|
||||||
"back_wheel": int(raw_speeds[1]),
|
|
||||||
"right_wheel": int(raw_speeds[2]),
|
|
||||||
}
|
|
||||||
|
|
||||||
def set_velocity(self, command_speeds):
|
|
||||||
"""
|
|
||||||
Sends raw velocity commands (16-bit encoded values) directly to the motor bus.
|
|
||||||
The order of speeds must correspond to self.motor_ids.
|
|
||||||
"""
|
|
||||||
self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids)
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
"""Stops the robot by setting all motor speeds to zero."""
|
|
||||||
self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids)
|
|
||||||
print("Motors stopped.")
|
|
|
@ -31,9 +31,15 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Install LeRobot with dependencies for the feetech motors:
|
5. Install ffmpeg in your environment:
|
||||||
|
When using `miniconda`, install `ffmpeg` in your environment:
|
||||||
```bash
|
```bash
|
||||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
conda install ffmpeg -c conda-forge
|
||||||
|
```
|
||||||
|
|
||||||
|
6. Install LeRobot with dependencies for the feetech motors:
|
||||||
|
```bash
|
||||||
|
cd ~/lerobot && pip install -e ".[feetech]"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configure the motors
|
## Configure the motors
|
||||||
|
@ -212,6 +218,9 @@ python lerobot/scripts/control_robot.py \
|
||||||
|
|
||||||
**Teleop with displaying cameras**
|
**Teleop with displaying cameras**
|
||||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||||
|
|
||||||
|
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py \
|
python lerobot/scripts/control_robot.py \
|
||||||
--robot.type=moss \
|
--robot.type=moss \
|
||||||
|
|
|
@ -23,7 +23,9 @@ class Robot(abc.ABC):
|
||||||
self.robot_type = self.name
|
self.robot_type = self.name
|
||||||
self.id = config.id
|
self.id = config.id
|
||||||
self.calibration_dir = (
|
self.calibration_dir = (
|
||||||
config.calibration_dir if config.calibration_dir else HF_LEROBOT_CALIBRATION / ROBOTS / self.name
|
Path(config.calibration_dir)
|
||||||
|
if config.calibration_dir
|
||||||
|
else Path(HF_LEROBOT_CALIBRATION / ROBOTS / self.name)
|
||||||
)
|
)
|
||||||
self.calibration_dir.mkdir(parents=True, exist_ok=True)
|
self.calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.calibration_fpath = self.calibration_dir / f"{self.id}.json"
|
self.calibration_fpath = self.calibration_dir / f"{self.id}.json"
|
||||||
|
|
|
@ -57,9 +57,15 @@ conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
#### 5. Install ffmpeg in your environment:
|
||||||
|
When using `miniconda`, install `ffmpeg` in your environment:
|
||||||
```bash
|
```bash
|
||||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
conda install ffmpeg -c conda-forge
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 6. Install LeRobot with dependencies for the feetech motors:
|
||||||
|
```bash
|
||||||
|
cd ~/lerobot && pip install -e ".[feetech]"
|
||||||
```
|
```
|
||||||
|
|
||||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
|
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
|
||||||
|
@ -491,6 +497,9 @@ python lerobot/scripts/control_robot.py \
|
||||||
|
|
||||||
#### a. Teleop with displaying cameras
|
#### a. Teleop with displaying cameras
|
||||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||||
|
|
||||||
|
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py \
|
python lerobot/scripts/control_robot.py \
|
||||||
--robot.type=so100 \
|
--robot.type=so100 \
|
||||||
|
|
|
@ -55,6 +55,7 @@ class SO100Follower(Robot):
|
||||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
|
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||||
},
|
},
|
||||||
|
calibration=self.calibration,
|
||||||
)
|
)
|
||||||
self.cameras = make_cameras_from_configs(config.cameras)
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
|
||||||
|
@ -120,7 +121,7 @@ class SO100Follower(Robot):
|
||||||
|
|
||||||
full_turn_motor = "wrist_roll"
|
full_turn_motor = "wrist_roll"
|
||||||
unknown_range_motors = [name for name in self.arm.names if name != full_turn_motor]
|
unknown_range_motors = [name for name in self.arm.names if name != full_turn_motor]
|
||||||
logger.info(
|
print(
|
||||||
f"Move all joints except '{full_turn_motor}' sequentially through their "
|
f"Move all joints except '{full_turn_motor}' sequentially through their "
|
||||||
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||||
)
|
)
|
||||||
|
@ -143,7 +144,7 @@ class SO100Follower(Robot):
|
||||||
print("Calibration saved to", self.calibration_fpath)
|
print("Calibration saved to", self.calibration_fpath)
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
self.arm.disable_torque()
|
with self.arm.torque_disabled():
|
||||||
self.arm.configure_motors()
|
self.arm.configure_motors()
|
||||||
for name in self.arm.names:
|
for name in self.arm.names:
|
||||||
self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value)
|
self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value)
|
||||||
|
@ -152,12 +153,6 @@ class SO100Follower(Robot):
|
||||||
# Set I_Coefficient and D_Coefficient to default value 0 and 32
|
# Set I_Coefficient and D_Coefficient to default value 0 and 32
|
||||||
self.arm.write("I_Coefficient", name, 0)
|
self.arm.write("I_Coefficient", name, 0)
|
||||||
self.arm.write("D_Coefficient", name, 32)
|
self.arm.write("D_Coefficient", name, 32)
|
||||||
# Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
|
|
||||||
# the motors. Note: this address is not in the official STS3215 Memory Table
|
|
||||||
self.arm.write("Maximum_Acceleration", name, 254)
|
|
||||||
self.arm.write("Acceleration", name, 254)
|
|
||||||
|
|
||||||
self.arm.enable_torque()
|
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, Any]:
|
def get_observation(self) -> dict[str, Any]:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
|
|
|
@ -43,14 +43,19 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
6. Install LeRobot with stretch dependencies:
|
6. When using `miniconda`, install `ffmpeg` in your environment:
|
||||||
```bash
|
```bash
|
||||||
cd ~/lerobot && pip install --no-binary=av -e ".[stretch]"
|
conda install ffmpeg -c conda-forge
|
||||||
|
```
|
||||||
|
|
||||||
|
7. Install LeRobot with stretch dependencies:
|
||||||
|
```bash
|
||||||
|
cd ~/lerobot && pip install -e ".[stretch]"
|
||||||
```
|
```
|
||||||
|
|
||||||
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
|
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
|
||||||
|
|
||||||
7. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
|
8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
|
||||||
```bash
|
```bash
|
||||||
stretch_system_check.py
|
stretch_system_check.py
|
||||||
```
|
```
|
||||||
|
@ -97,6 +102,8 @@ This is equivalent to running `stretch_robot_home.py`
|
||||||
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
||||||
|
|
||||||
Now try out teleoperation (see above documentation to learn about the gamepad controls):
|
Now try out teleoperation (see above documentation to learn about the gamepad controls):
|
||||||
|
|
||||||
|
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py \
|
python lerobot/scripts/control_robot.py \
|
||||||
--robot.type=stretch \
|
--robot.type=stretch \
|
||||||
|
|
|
@ -49,25 +49,26 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||||
|
|
||||||
return Stretch3RobotConfig(**kwargs)
|
return Stretch3RobotConfig(**kwargs)
|
||||||
elif robot_type == "lekiwi":
|
elif robot_type == "lekiwi":
|
||||||
from .lekiwi.configuration_lekiwi import LeKiwiRobotConfig
|
from .lekiwi.config_lekiwi import LeKiwiConfig
|
||||||
|
|
||||||
return LeKiwiRobotConfig(**kwargs)
|
return LeKiwiConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
||||||
|
|
||||||
|
|
||||||
def make_robot_from_config(config: RobotConfig):
|
def make_robot_from_config(config: RobotConfig):
|
||||||
from .lekiwi.configuration_lekiwi import LeKiwiRobotConfig
|
from .lekiwi.config_lekiwi import LeKiwiConfig
|
||||||
from .manipulator import ManipulatorRobotConfig
|
from .manipulator import ManipulatorRobotConfig
|
||||||
|
|
||||||
if isinstance(config, ManipulatorRobotConfig):
|
if isinstance(config, ManipulatorRobotConfig):
|
||||||
from lerobot.common.robots.manipulator import ManipulatorRobot
|
from lerobot.common.robots.manipulator import ManipulatorRobot
|
||||||
|
|
||||||
return ManipulatorRobot(config)
|
return ManipulatorRobot(config)
|
||||||
elif isinstance(config, LeKiwiRobotConfig):
|
elif isinstance(config, LeKiwiConfig):
|
||||||
from lerobot.common.robots.mobile_manipulator import MobileManipulator
|
# TODO(Steven): Change when we decide what to do with these scripts
|
||||||
|
# from lerobot.common.robots.mobile_manipulator import MobileManipulator
|
||||||
return MobileManipulator(config)
|
# return MobileManipulator(config)
|
||||||
|
...
|
||||||
else:
|
else:
|
||||||
from lerobot.common.robots.stretch3.robot_stretch3 import Stretch3Robot
|
from lerobot.common.robots.stretch3.robot_stretch3 import Stretch3Robot
|
||||||
|
|
||||||
|
|
|
@ -30,9 +30,14 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
|
5. When using `miniconda`, install `ffmpeg` in your environment:
|
||||||
```bash
|
```bash
|
||||||
cd ~/lerobot && pip install --no-binary=av -e ".[dynamixel, intelrealsense]"
|
conda install ffmpeg -c conda-forge
|
||||||
|
```
|
||||||
|
|
||||||
|
6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
|
||||||
|
```bash
|
||||||
|
cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Teleoperate
|
## Teleoperate
|
||||||
|
@ -43,6 +48,9 @@ Teleoperation consists in manually operating the leader arms to move the followe
|
||||||
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
|
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
|
||||||
|
|
||||||
By running the following code, you can start your first **SAFE** teleoperation:
|
By running the following code, you can start your first **SAFE** teleoperation:
|
||||||
|
|
||||||
|
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py \
|
python lerobot/scripts/control_robot.py \
|
||||||
--robot.type=aloha \
|
--robot.type=aloha \
|
||||||
|
|
|
@ -117,7 +117,7 @@ class ViperX(Robot):
|
||||||
|
|
||||||
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
||||||
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
||||||
logger.info(
|
print(
|
||||||
f"Move all joints except {full_turn_motors} sequentially through their entire "
|
f"Move all joints except {full_turn_motors} sequentially through their entire "
|
||||||
"ranges of motion.\nRecording positions. Press ENTER to stop..."
|
"ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||||
)
|
)
|
||||||
|
@ -141,7 +141,7 @@ class ViperX(Robot):
|
||||||
logger.info(f"Calibration saved to {self.calibration_fpath}")
|
logger.info(f"Calibration saved to {self.calibration_fpath}")
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
self.arm.disable_torque()
|
with self.arm.torque_disabled():
|
||||||
self.arm.configure_motors()
|
self.arm.configure_motors()
|
||||||
|
|
||||||
# Set secondary/shadow ID for shoulder and elbow. These joints have two motors.
|
# Set secondary/shadow ID for shoulder and elbow. These joints have two motors.
|
||||||
|
@ -154,19 +154,18 @@ class ViperX(Robot):
|
||||||
# TODO(aliberts): remove as it's actually useless in position control
|
# TODO(aliberts): remove as it's actually useless in position control
|
||||||
self.arm.write("Velocity_Limit", 131)
|
self.arm.write("Velocity_Limit", 131)
|
||||||
|
|
||||||
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
|
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
|
||||||
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
|
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
|
||||||
# you could end up with a servo with a position 0 or 4095 at a crucial point. See:
|
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point.
|
||||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
|
# See: https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
|
||||||
for name in self.arm.names:
|
for name in self.arm.names:
|
||||||
if name != "gripper":
|
if name != "gripper":
|
||||||
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
|
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
|
||||||
|
|
||||||
# Use 'position control current based' for follower gripper to be limited by the limit of the current.
|
# Use 'position control current based' for follower gripper to be limited by the limit of the
|
||||||
# It can grasp an object without forcing too much even tho, it's goal position is a complete grasp
|
# current. It can grasp an object without forcing too much even tho, it's goal position is a
|
||||||
# (both gripper fingers are ordered to join and reach a touch).
|
# complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||||
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||||
self.arm.enable_torque()
|
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, Any]:
|
def get_observation(self) -> dict[str, Any]:
|
||||||
"""The returned observations do not have a batch dimension."""
|
"""The returned observations do not have a batch dimension."""
|
||||||
|
|
|
@ -22,4 +22,5 @@ from ..config import TeleoperatorConfig
|
||||||
@TeleoperatorConfig.register_subclass("keyboard")
|
@TeleoperatorConfig.register_subclass("keyboard")
|
||||||
@dataclass
|
@dataclass
|
||||||
class KeyboardTeleopConfig(TeleoperatorConfig):
|
class KeyboardTeleopConfig(TeleoperatorConfig):
|
||||||
|
# TODO(Steven): Consider setting in here the keys that we want to capture/listen
|
||||||
mock: bool = False
|
mock: bool = False
|
||||||
|
|
|
@ -19,8 +19,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
from typing import Any
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
|
|
||||||
|
@ -59,48 +58,59 @@ class KeyboardTeleop(Teleoperator):
|
||||||
self.event_queue = Queue()
|
self.event_queue = Queue()
|
||||||
self.current_pressed = {}
|
self.current_pressed = {}
|
||||||
self.listener = None
|
self.listener = None
|
||||||
self.is_connected = False
|
self._is_connected = False
|
||||||
self.logs = {}
|
self.logs = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_feature(self) -> dict:
|
def action_feature(self) -> dict:
|
||||||
return {
|
# TODO(Steven): Change this when we agree what should this return
|
||||||
"dtype": "float32",
|
...
|
||||||
"shape": (len(self.arm),),
|
|
||||||
"names": {"motors": list(self.arm.motors)},
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def feedback_feature(self) -> dict:
|
def feedback_feature(self) -> dict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self._is_connected
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
def connect(self) -> None:
|
def connect(self) -> None:
|
||||||
if self.is_connected:
|
# TODO(Steven): Consider early return instead of raising a warning
|
||||||
|
# if self._is_connected:
|
||||||
|
# logging.warning(
|
||||||
|
# "Keyboard is already connected. Do not run `robot.connect()` twice."
|
||||||
|
# )
|
||||||
|
# return self._is_connected
|
||||||
|
if self._is_connected:
|
||||||
raise DeviceAlreadyConnectedError(
|
raise DeviceAlreadyConnectedError(
|
||||||
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
|
"Keyboard is already connected. Do not run `robot.connect()` twice."
|
||||||
)
|
)
|
||||||
|
|
||||||
if PYNPUT_AVAILABLE:
|
if PYNPUT_AVAILABLE:
|
||||||
logging.info("pynput is available - enabling local keyboard listener.")
|
logging.info("pynput is available - enabling local keyboard listener.")
|
||||||
self.listener = keyboard.Listener(
|
self.listener = keyboard.Listener(
|
||||||
on_press=self.on_press,
|
on_press=self._on_press,
|
||||||
on_release=self.on_release,
|
on_release=self._on_release,
|
||||||
)
|
)
|
||||||
self.listener.start()
|
self.listener.start()
|
||||||
else:
|
else:
|
||||||
logging.info("pynput not available - skipping local keyboard listener.")
|
logging.info("pynput not available - skipping local keyboard listener.")
|
||||||
self.listener = None
|
self.listener = None
|
||||||
|
|
||||||
self.is_connected = True
|
self._is_connected = True
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
def calibrate(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_press(self, key):
|
def _on_press(self, key):
|
||||||
if hasattr(key, "char"):
|
if hasattr(key, "char"):
|
||||||
self.event_queue.put((key.char, True))
|
self.event_queue.put((key.char, True))
|
||||||
|
|
||||||
def on_release(self, key):
|
def _on_release(self, key):
|
||||||
if hasattr(key, "char"):
|
if hasattr(key, "char"):
|
||||||
self.event_queue.put((key.char, False))
|
self.event_queue.put((key.char, False))
|
||||||
if key == keyboard.Key.esc:
|
if key == keyboard.Key.esc:
|
||||||
|
@ -112,10 +122,13 @@ class KeyboardTeleop(Teleoperator):
|
||||||
key_char, is_pressed = self.event_queue.get_nowait()
|
key_char, is_pressed = self.event_queue.get_nowait()
|
||||||
self.current_pressed[key_char] = is_pressed
|
self.current_pressed[key_char] = is_pressed
|
||||||
|
|
||||||
def get_action(self) -> np.ndarray:
|
def configure(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_action(self) -> dict[str, Any]:
|
||||||
before_read_t = time.perf_counter()
|
before_read_t = time.perf_counter()
|
||||||
|
|
||||||
if not self.is_connected:
|
if not self._is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
|
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
|
||||||
)
|
)
|
||||||
|
@ -126,17 +139,17 @@ class KeyboardTeleop(Teleoperator):
|
||||||
action = {key for key, val in self.current_pressed.items() if val}
|
action = {key for key, val in self.current_pressed.items() if val}
|
||||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||||
|
|
||||||
return np.array(list(action))
|
return dict.fromkeys(action, None)
|
||||||
|
|
||||||
def send_feedback(self, feedback: np.ndarray) -> None:
|
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
if not self.is_connected:
|
if not self._is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
"KeyboardTeleop is not connected. You need to run `robot.connect()` before `disconnect()`."
|
"KeyboardTeleop is not connected. You need to run `robot.connect()` before `disconnect()`."
|
||||||
)
|
)
|
||||||
if self.listener is not None:
|
if self.listener is not None:
|
||||||
self.listener.stop()
|
self.listener.stop()
|
||||||
|
|
||||||
self.is_connected = False
|
self._is_connected = False
|
||||||
|
|
|
@ -102,7 +102,7 @@ class KochLeader(Teleoperator):
|
||||||
|
|
||||||
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
||||||
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
||||||
logger.info(
|
print(
|
||||||
f"Move all joints except {full_turn_motors} sequentially through their "
|
f"Move all joints except {full_turn_motors} sequentially through their "
|
||||||
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,3 +24,4 @@ from ..config import TeleoperatorConfig
|
||||||
class SO100LeaderConfig(TeleoperatorConfig):
|
class SO100LeaderConfig(TeleoperatorConfig):
|
||||||
# Port to connect to the arm
|
# Port to connect to the arm
|
||||||
port: str
|
port: str
|
||||||
|
id = "so100"
|
||||||
|
|
|
@ -51,6 +51,7 @@ class SO100Leader(Teleoperator):
|
||||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
|
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||||
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||||
},
|
},
|
||||||
|
calibration=self.calibration,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -95,7 +96,7 @@ class SO100Leader(Teleoperator):
|
||||||
|
|
||||||
full_turn_motor = "wrist_roll"
|
full_turn_motor = "wrist_roll"
|
||||||
unknown_range_motors = [name for name in self.arm.names if name != full_turn_motor]
|
unknown_range_motors = [name for name in self.arm.names if name != full_turn_motor]
|
||||||
logger.info(
|
print(
|
||||||
f"Move all joints except '{full_turn_motor}' sequentially through their "
|
f"Move all joints except '{full_turn_motor}' sequentially through their "
|
||||||
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||||
)
|
)
|
||||||
|
|
|
@ -99,7 +99,7 @@ class WidowX(Teleoperator):
|
||||||
|
|
||||||
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
||||||
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
||||||
logger.info(
|
print(
|
||||||
f"Move all joints except {full_turn_motors} sequentially through their "
|
f"Move all joints except {full_turn_motors} sequentially through their "
|
||||||
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,7 +24,7 @@ from contextlib import nullcontext
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
|
||||||
import cv2
|
import rerun as rr
|
||||||
import torch
|
import torch
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
@ -174,13 +174,13 @@ def warmup_record(
|
||||||
events,
|
events,
|
||||||
enable_teleoperation,
|
enable_teleoperation,
|
||||||
warmup_time_s,
|
warmup_time_s,
|
||||||
display_cameras,
|
display_data,
|
||||||
fps,
|
fps,
|
||||||
):
|
):
|
||||||
control_loop(
|
control_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
control_time_s=warmup_time_s,
|
control_time_s=warmup_time_s,
|
||||||
display_cameras=display_cameras,
|
display_data=display_data,
|
||||||
events=events,
|
events=events,
|
||||||
fps=fps,
|
fps=fps,
|
||||||
teleoperate=enable_teleoperation,
|
teleoperate=enable_teleoperation,
|
||||||
|
@ -192,7 +192,7 @@ def record_episode(
|
||||||
dataset,
|
dataset,
|
||||||
events,
|
events,
|
||||||
episode_time_s,
|
episode_time_s,
|
||||||
display_cameras,
|
display_data,
|
||||||
policy,
|
policy,
|
||||||
fps,
|
fps,
|
||||||
single_task,
|
single_task,
|
||||||
|
@ -200,7 +200,7 @@ def record_episode(
|
||||||
control_loop(
|
control_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
control_time_s=episode_time_s,
|
control_time_s=episode_time_s,
|
||||||
display_cameras=display_cameras,
|
display_data=display_data,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
events=events,
|
events=events,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
|
@ -215,7 +215,7 @@ def control_loop(
|
||||||
robot,
|
robot,
|
||||||
control_time_s=None,
|
control_time_s=None,
|
||||||
teleoperate=False,
|
teleoperate=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
dataset: LeRobotDataset | None = None,
|
dataset: LeRobotDataset | None = None,
|
||||||
events=None,
|
events=None,
|
||||||
policy: PreTrainedPolicy = None,
|
policy: PreTrainedPolicy = None,
|
||||||
|
@ -264,11 +264,15 @@ def control_loop(
|
||||||
frame = {**observation, **action, "task": single_task}
|
frame = {**observation, **action, "task": single_task}
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
if display_cameras and not is_headless():
|
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
|
||||||
|
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
|
||||||
|
for k, v in action.items():
|
||||||
|
for i, vv in enumerate(v):
|
||||||
|
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
||||||
|
|
||||||
image_keys = [key for key in observation if "image" in key]
|
image_keys = [key for key in observation if "image" in key]
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
rr.log(key, rr.Image(observation[key].numpy()), static=True)
|
||||||
cv2.waitKey(1)
|
|
||||||
|
|
||||||
if fps is not None:
|
if fps is not None:
|
||||||
dt_s = time.perf_counter() - start_loop_t
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
@ -297,16 +301,12 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stop_recording(robot, listener, display_cameras):
|
def stop_recording(robot, listener, display_data):
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
|
|
||||||
if not is_headless():
|
if not is_headless() and listener is not None:
|
||||||
if listener is not None:
|
|
||||||
listener.stop()
|
listener.stop()
|
||||||
|
|
||||||
if display_cameras:
|
|
||||||
cv2.destroyAllWindows()
|
|
||||||
|
|
||||||
|
|
||||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||||
_, dataset_name = repo_id.split("/")
|
_, dataset_name = repo_id.split("/")
|
||||||
|
|
|
@ -41,7 +41,7 @@ class TeleoperateControlConfig(ControlConfig):
|
||||||
fps: int | None = None
|
fps: int | None = None
|
||||||
teleop_time_s: float | None = None
|
teleop_time_s: float | None = None
|
||||||
# Display all cameras on screen
|
# Display all cameras on screen
|
||||||
display_cameras: bool = True
|
display_data: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ControlConfig.register_subclass("record")
|
@ControlConfig.register_subclass("record")
|
||||||
|
@ -82,7 +82,7 @@ class RecordControlConfig(ControlConfig):
|
||||||
# Not enough threads might cause low camera fps.
|
# Not enough threads might cause low camera fps.
|
||||||
num_image_writer_threads_per_camera: int = 4
|
num_image_writer_threads_per_camera: int = 4
|
||||||
# Display all cameras on screen
|
# Display all cameras on screen
|
||||||
display_cameras: bool = True
|
display_data: bool = False
|
||||||
# Use vocal synthesis to read events.
|
# Use vocal synthesis to read events.
|
||||||
play_sounds: bool = True
|
play_sounds: bool = True
|
||||||
# Resume recording on an existing dataset.
|
# Resume recording on an existing dataset.
|
||||||
|
@ -116,6 +116,11 @@ class ReplayControlConfig(ControlConfig):
|
||||||
@dataclass
|
@dataclass
|
||||||
class RemoteRobotConfig(ControlConfig):
|
class RemoteRobotConfig(ControlConfig):
|
||||||
log_interval: int = 100
|
log_interval: int = 100
|
||||||
|
# Display all cameras on screen
|
||||||
|
display_data: bool = False
|
||||||
|
# Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp)
|
||||||
|
viewer_ip: str | None = None
|
||||||
|
viewer_port: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -135,10 +135,13 @@ python lerobot/scripts/control_robot.py \
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
|
import rerun as rr
|
||||||
|
|
||||||
# from safetensors.torch import load_file, save_file
|
# from safetensors.torch import load_file, save_file
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
@ -146,6 +149,7 @@ from lerobot.common.robots.utils import Robot, make_robot_from_config
|
||||||
from lerobot.common.utils.control_utils import (
|
from lerobot.common.utils.control_utils import (
|
||||||
control_loop,
|
control_loop,
|
||||||
init_keyboard_listener,
|
init_keyboard_listener,
|
||||||
|
is_headless,
|
||||||
log_control_info,
|
log_control_info,
|
||||||
record_episode,
|
record_episode,
|
||||||
reset_environment,
|
reset_environment,
|
||||||
|
@ -159,6 +163,7 @@ from lerobot.common.utils.utils import has_method, init_logging, log_say
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.control import (
|
from lerobot.configs.control import (
|
||||||
CalibrateControlConfig,
|
CalibrateControlConfig,
|
||||||
|
ControlConfig,
|
||||||
ControlPipelineConfig,
|
ControlPipelineConfig,
|
||||||
RecordControlConfig,
|
RecordControlConfig,
|
||||||
RemoteRobotConfig,
|
RemoteRobotConfig,
|
||||||
|
@ -232,7 +237,7 @@ def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
|
||||||
control_time_s=cfg.teleop_time_s,
|
control_time_s=cfg.teleop_time_s,
|
||||||
fps=cfg.fps,
|
fps=cfg.fps,
|
||||||
teleoperate=True,
|
teleoperate=True,
|
||||||
display_cameras=cfg.display_cameras,
|
display_data=cfg.display_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -280,7 +285,7 @@ def record(
|
||||||
# 3. place the cameras windows on screen
|
# 3. place the cameras windows on screen
|
||||||
enable_teleoperation = policy is None
|
enable_teleoperation = policy is None
|
||||||
log_say("Warmup record", cfg.play_sounds)
|
log_say("Warmup record", cfg.play_sounds)
|
||||||
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
|
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)
|
||||||
|
|
||||||
if has_method(robot, "teleop_safety_stop"):
|
if has_method(robot, "teleop_safety_stop"):
|
||||||
robot.teleop_safety_stop()
|
robot.teleop_safety_stop()
|
||||||
|
@ -296,7 +301,7 @@ def record(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
events=events,
|
events=events,
|
||||||
episode_time_s=cfg.episode_time_s,
|
episode_time_s=cfg.episode_time_s,
|
||||||
display_cameras=cfg.display_cameras,
|
display_data=cfg.display_data,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
fps=cfg.fps,
|
fps=cfg.fps,
|
||||||
single_task=cfg.single_task,
|
single_task=cfg.single_task,
|
||||||
|
@ -326,7 +331,7 @@ def record(
|
||||||
break
|
break
|
||||||
|
|
||||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||||
stop_recording(robot, listener, cfg.display_cameras)
|
stop_recording(robot, listener, cfg.display_data)
|
||||||
|
|
||||||
if cfg.push_to_hub:
|
if cfg.push_to_hub:
|
||||||
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
||||||
|
@ -363,6 +368,40 @@ def replay(
|
||||||
log_control_info(robot, dt_s, fps=cfg.fps)
|
log_control_info(robot, dt_s, fps=cfg.fps)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None:
|
||||||
|
"""Initializes the Rerun SDK for visualizing the control loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
control_config: Configuration determining data display and robot type.
|
||||||
|
session_name: Rerun session name. Defaults to "lerobot_control_loop".
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If viewer IP is missing for non-remote configurations with display enabled.
|
||||||
|
"""
|
||||||
|
if (control_config.display_data and not is_headless()) or (
|
||||||
|
control_config.display_data and isinstance(control_config, RemoteRobotConfig)
|
||||||
|
):
|
||||||
|
# Configure Rerun flush batch size default to 8KB if not set
|
||||||
|
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||||
|
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||||
|
|
||||||
|
# Initialize Rerun based on configuration
|
||||||
|
rr.init(session_name)
|
||||||
|
if isinstance(control_config, RemoteRobotConfig):
|
||||||
|
viewer_ip = control_config.viewer_ip
|
||||||
|
viewer_port = control_config.viewer_port
|
||||||
|
if not viewer_ip or not viewer_port:
|
||||||
|
raise ValueError(
|
||||||
|
"Viewer IP & Port are required for remote config. Set via config file/CLI or disable control_config.display_data."
|
||||||
|
)
|
||||||
|
logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}")
|
||||||
|
rr.connect_tcp(f"{viewer_ip}:{viewer_port}")
|
||||||
|
else:
|
||||||
|
# Get memory limit for rerun viewer parameters
|
||||||
|
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
|
||||||
|
rr.spawn(memory_limit=memory_limit)
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def control_robot(cfg: ControlPipelineConfig):
|
def control_robot(cfg: ControlPipelineConfig):
|
||||||
init_logging()
|
init_logging()
|
||||||
|
@ -370,18 +409,24 @@ def control_robot(cfg: ControlPipelineConfig):
|
||||||
|
|
||||||
robot = make_robot_from_config(cfg.robot)
|
robot = make_robot_from_config(cfg.robot)
|
||||||
|
|
||||||
|
# TODO(Steven): Blueprint for fixed window size
|
||||||
|
|
||||||
if isinstance(cfg.control, CalibrateControlConfig):
|
if isinstance(cfg.control, CalibrateControlConfig):
|
||||||
calibrate(robot, cfg.control)
|
calibrate(robot, cfg.control)
|
||||||
elif isinstance(cfg.control, TeleoperateControlConfig):
|
elif isinstance(cfg.control, TeleoperateControlConfig):
|
||||||
|
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop")
|
||||||
teleoperate(robot, cfg.control)
|
teleoperate(robot, cfg.control)
|
||||||
elif isinstance(cfg.control, RecordControlConfig):
|
elif isinstance(cfg.control, RecordControlConfig):
|
||||||
|
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record")
|
||||||
record(robot, cfg.control)
|
record(robot, cfg.control)
|
||||||
elif isinstance(cfg.control, ReplayControlConfig):
|
elif isinstance(cfg.control, ReplayControlConfig):
|
||||||
replay(robot, cfg.control)
|
replay(robot, cfg.control)
|
||||||
elif isinstance(cfg.control, RemoteRobotConfig):
|
elif isinstance(cfg.control, RemoteRobotConfig):
|
||||||
from lerobot.common.robots.lekiwi.lekiwi_remote import run_lekiwi
|
...
|
||||||
|
# TODO(Steven): Change this when we decide what to do with the control_robot script
|
||||||
run_lekiwi(cfg.robot)
|
# from lerobot.common.robots.lekiwi.old_lekiwi_remote import run_lekiwi
|
||||||
|
# _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote")
|
||||||
|
# run_lekiwi(cfg.robot)
|
||||||
|
|
||||||
if robot.is_connected:
|
if robot.is_connected:
|
||||||
# Disconnect manually to avoid a "Core dump" during process
|
# Disconnect manually to avoid a "Core dump" during process
|
||||||
|
|
|
@ -60,9 +60,9 @@ dependencies = [
|
||||||
"jsonlines>=4.0.0",
|
"jsonlines>=4.0.0",
|
||||||
"numba>=0.59.0",
|
"numba>=0.59.0",
|
||||||
"omegaconf>=2.3.0",
|
"omegaconf>=2.3.0",
|
||||||
"opencv-python>=4.9.0",
|
"opencv-python-headless>=4.9.0",
|
||||||
"packaging>=24.2",
|
"packaging>=24.2",
|
||||||
"av>=12.0.5,<13.0.0",
|
"av>=12.0.5",
|
||||||
"pymunk>=6.6.0",
|
"pymunk>=6.6.0",
|
||||||
"pynput>=1.7.7",
|
"pynput>=1.7.7",
|
||||||
"pyzmq>=26.2.1",
|
"pyzmq>=26.2.1",
|
||||||
|
|
|
@ -5,7 +5,7 @@ import dynamixel_sdk as dxl
|
||||||
import serial
|
import serial
|
||||||
from mock_serial.mock_serial import MockSerial
|
from mock_serial.mock_serial import MockSerial
|
||||||
|
|
||||||
from lerobot.common.motors.dynamixel import X_SERIES_CONTROL_TABLE, DynamixelMotorsBus
|
from lerobot.common.motors.dynamixel.dynamixel import _split_into_byte_chunks
|
||||||
|
|
||||||
from .mock_serial_patch import WaitableStub
|
from .mock_serial_patch import WaitableStub
|
||||||
|
|
||||||
|
@ -45,41 +45,6 @@ DXL_CRC_TABLE = [
|
||||||
0x8213, 0x0216, 0x021C, 0x8219, 0x0208, 0x820D, 0x8207, 0x0202
|
0x8213, 0x0216, 0x021C, 0x8219, 0x0208, 0x820D, 0x8207, 0x0202
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
|
|
||||||
# https://emanual.robotis.com/docs/en/dxl/protocol2/#instruction
|
|
||||||
INSTRUCTION_TYPES = {
|
|
||||||
"Ping": dxl.INST_PING, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID
|
|
||||||
"Read": dxl.INST_READ, # Read data from the Device
|
|
||||||
"Write": dxl.INST_WRITE, # Write data to the Device
|
|
||||||
"Reg_Write": dxl.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command
|
|
||||||
"Action": dxl.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write
|
|
||||||
"Factory_Reset": dxl.INST_FACTORY_RESET, # Resets the Control Table to its initial factory default settings
|
|
||||||
"Reboot": dxl.INST_REBOOT, # Reboot the Device
|
|
||||||
"Clear": dxl.INST_CLEAR, # Reset certain information stored in memory
|
|
||||||
"Control_Table_Backup": 0x20, # Store current Control Table status data to a Backup or to restore backup EEPROM data.
|
|
||||||
"Status": dxl.INST_STATUS, # Return packet sent following the execution of an Instruction Packet
|
|
||||||
"Sync_Read": dxl.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once
|
|
||||||
"Sync_Write": dxl.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once
|
|
||||||
"Fast_Sync_Read": 0x8A, # Read data from multiple devices with the same Address with the same length at once
|
|
||||||
"Bulk_Read": dxl.INST_BULK_READ, # Read data from multiple devices with different Addresses with different lengths at once
|
|
||||||
"Bulk_Write": dxl.INST_BULK_WRITE, # Write data to multiple devices with different Addresses with different lengths at once
|
|
||||||
"Fast_Bulk_Read": 0x9A, # Read data from multiple devices with different Addresses with different lengths at once
|
|
||||||
} # fmt: skip
|
|
||||||
|
|
||||||
# https://emanual.robotis.com/docs/en/dxl/protocol2/#error
|
|
||||||
ERROR_TYPE = {
|
|
||||||
"Success": 0x00, # No error
|
|
||||||
"Result_Fail": dxl.ERRNUM_RESULT_FAIL, # Failed to process the sent Instruction Packet
|
|
||||||
"Instruction_Error": dxl.ERRNUM_INSTRUCTION, # An undefined Instruction has been usedAction has been used without Reg Write
|
|
||||||
"CRC_Error": dxl.ERRNUM_CRC, # The CRC of the sent Packet does not match the expected value
|
|
||||||
"Data_Range_Error": dxl.ERRNUM_DATA_RANGE, # Data to be written to the specified Address is outside the range of the minimum/maximum value
|
|
||||||
"Data_Length_Error": dxl.ERRNUM_DATA_LENGTH, # Attempted to write Data that is shorter than the required data length of the specified Address
|
|
||||||
# (ex: when you attempt to only use 2 bytes of a register that has been defined as 4 bytes)
|
|
||||||
"Data_Limit_Error": dxl.ERRNUM_DATA_LIMIT, # Data to be written to the specified Address is outside of the configured Limit value
|
|
||||||
"Access_Error": dxl.ERRNUM_ACCESS, # Attempted to write a value to an Address that is Read Only or has not been defined
|
|
||||||
# Attempted to read a value from an Address that is Write Only or has not been defined
|
|
||||||
# Attempted to write a value to an EEPROM register while Torque was Enabled.
|
|
||||||
} # fmt: skip
|
|
||||||
|
|
||||||
|
|
||||||
class MockDynamixelPacketv2(abc.ABC):
|
class MockDynamixelPacketv2(abc.ABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -186,14 +151,14 @@ class MockInstructionPacket(MockDynamixelPacketv2):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build(cls, dxl_id: int, params: list[int], length: int, instruct_type: str) -> list[int]:
|
def _build(cls, dxl_id: int, params: list[int], length: int, instruction: int) -> list[int]:
|
||||||
instruct_value = INSTRUCTION_TYPES[instruct_type]
|
length = len(params) + 3
|
||||||
return [
|
return [
|
||||||
0xFF, 0xFF, 0xFD, 0x00, # header
|
0xFF, 0xFF, 0xFD, 0x00, # header
|
||||||
dxl_id, # servo id
|
dxl_id, # servo id
|
||||||
dxl.DXL_LOBYTE(length), # length_l
|
dxl.DXL_LOBYTE(length), # length_l
|
||||||
dxl.DXL_HIBYTE(length), # length_h
|
dxl.DXL_HIBYTE(length), # length_h
|
||||||
instruct_value, # instruction type
|
instruction, # instruction type
|
||||||
*params, # data bytes
|
*params, # data bytes
|
||||||
0x00, 0x00 # placeholder for CRC
|
0x00, 0x00 # placeholder for CRC
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
|
@ -209,8 +174,39 @@ class MockInstructionPacket(MockDynamixelPacketv2):
|
||||||
|
|
||||||
No parameters required.
|
No parameters required.
|
||||||
"""
|
"""
|
||||||
params, length = [], 3
|
return cls.build(dxl_id=dxl_id, params=[], length=3, instruction=dxl.INST_PING)
|
||||||
return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Ping")
|
|
||||||
|
@classmethod
|
||||||
|
def read(
|
||||||
|
cls,
|
||||||
|
dxl_id: int,
|
||||||
|
start_address: int,
|
||||||
|
data_length: int,
|
||||||
|
) -> bytes:
|
||||||
|
"""
|
||||||
|
Builds a "Read" instruction.
|
||||||
|
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
|
||||||
|
|
||||||
|
The parameters for Read (Protocol 2.0) are:
|
||||||
|
param[0] = start_address L
|
||||||
|
param[1] = start_address H
|
||||||
|
param[2] = data_length L
|
||||||
|
param[3] = data_length H
|
||||||
|
|
||||||
|
And 'length' = data_length + 5, where:
|
||||||
|
+1 is for instruction byte,
|
||||||
|
+2 is for the length bytes,
|
||||||
|
+2 is for the CRC at the end.
|
||||||
|
"""
|
||||||
|
params = [
|
||||||
|
dxl.DXL_LOBYTE(start_address),
|
||||||
|
dxl.DXL_HIBYTE(start_address),
|
||||||
|
dxl.DXL_LOBYTE(data_length),
|
||||||
|
dxl.DXL_HIBYTE(data_length),
|
||||||
|
]
|
||||||
|
length = len(params) + 3
|
||||||
|
# length = data_length + 5
|
||||||
|
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_READ)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def write(
|
def write(
|
||||||
|
@ -237,14 +233,14 @@ class MockInstructionPacket(MockDynamixelPacketv2):
|
||||||
+2 is for the length bytes,
|
+2 is for the length bytes,
|
||||||
+2 is for the CRC at the end.
|
+2 is for the CRC at the end.
|
||||||
"""
|
"""
|
||||||
data = DynamixelMotorsBus._split_int_to_bytes(value, data_length)
|
data = _split_into_byte_chunks(value, data_length)
|
||||||
params = [
|
params = [
|
||||||
dxl.DXL_LOBYTE(start_address),
|
dxl.DXL_LOBYTE(start_address),
|
||||||
dxl.DXL_HIBYTE(start_address),
|
dxl.DXL_HIBYTE(start_address),
|
||||||
*data,
|
*data,
|
||||||
]
|
]
|
||||||
length = data_length + 5
|
length = data_length + 5
|
||||||
return cls.build(dxl_id=dxl_id, params=params, length=length, instruct_type="Write")
|
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_WRITE)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sync_read(
|
def sync_read(
|
||||||
|
@ -278,7 +274,9 @@ class MockInstructionPacket(MockDynamixelPacketv2):
|
||||||
*dxl_ids,
|
*dxl_ids,
|
||||||
]
|
]
|
||||||
length = len(dxl_ids) + 7
|
length = len(dxl_ids) + 7
|
||||||
return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read")
|
return cls.build(
|
||||||
|
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_READ
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sync_write(
|
def sync_write(
|
||||||
|
@ -315,7 +313,7 @@ class MockInstructionPacket(MockDynamixelPacketv2):
|
||||||
"""
|
"""
|
||||||
data = []
|
data = []
|
||||||
for id_, value in ids_values.items():
|
for id_, value in ids_values.items():
|
||||||
split_value = DynamixelMotorsBus._split_int_to_bytes(value, data_length)
|
split_value = _split_into_byte_chunks(value, data_length)
|
||||||
data += [id_, *split_value]
|
data += [id_, *split_value]
|
||||||
params = [
|
params = [
|
||||||
dxl.DXL_LOBYTE(start_address),
|
dxl.DXL_LOBYTE(start_address),
|
||||||
|
@ -325,7 +323,9 @@ class MockInstructionPacket(MockDynamixelPacketv2):
|
||||||
*data,
|
*data,
|
||||||
]
|
]
|
||||||
length = len(ids_values) * (1 + data_length) + 7
|
length = len(ids_values) * (1 + data_length) + 7
|
||||||
return cls.build(dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write")
|
return cls.build(
|
||||||
|
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_WRITE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MockStatusPacket(MockDynamixelPacketv2):
|
class MockStatusPacket(MockDynamixelPacketv2):
|
||||||
|
@ -341,21 +341,20 @@ class MockStatusPacket(MockDynamixelPacketv2):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build(cls, dxl_id: int, params: list[int], length: int, error: str = "Success") -> list[int]:
|
def _build(cls, dxl_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
|
||||||
err_byte = ERROR_TYPE[error]
|
|
||||||
return [
|
return [
|
||||||
0xFF, 0xFF, 0xFD, 0x00, # header
|
0xFF, 0xFF, 0xFD, 0x00, # header
|
||||||
dxl_id, # servo id
|
dxl_id, # servo id
|
||||||
dxl.DXL_LOBYTE(length), # length_l
|
dxl.DXL_LOBYTE(length), # length_l
|
||||||
dxl.DXL_HIBYTE(length), # length_h
|
dxl.DXL_HIBYTE(length), # length_h
|
||||||
0x55, # instruction = 'status'
|
0x55, # instruction = 'status'
|
||||||
err_byte, # error
|
error, # error
|
||||||
*params, # data bytes
|
*params, # data bytes
|
||||||
0x00, 0x00 # placeholder for CRC
|
0x00, 0x00 # placeholder for CRC
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50) -> bytes:
|
def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50, error: int = 0) -> bytes:
|
||||||
"""
|
"""
|
||||||
Builds a 'Ping' status packet.
|
Builds a 'Ping' status packet.
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01
|
https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01
|
||||||
|
@ -372,10 +371,10 @@ class MockStatusPacket(MockDynamixelPacketv2):
|
||||||
"""
|
"""
|
||||||
params = [dxl.DXL_LOBYTE(model_nb), dxl.DXL_HIBYTE(model_nb), firm_ver]
|
params = [dxl.DXL_LOBYTE(model_nb), dxl.DXL_HIBYTE(model_nb), firm_ver]
|
||||||
length = 7
|
length = 7
|
||||||
return cls.build(dxl_id, params=params, length=length)
|
return cls.build(dxl_id, params=params, length=length, error=error)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def read(cls, dxl_id: int, value: int, param_length: int) -> bytes:
|
def read(cls, dxl_id: int, value: int, param_length: int, error: int = 0) -> bytes:
|
||||||
"""
|
"""
|
||||||
Builds a 'Read' status packet (also works for 'Sync Read')
|
Builds a 'Read' status packet (also works for 'Sync Read')
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
|
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
|
||||||
|
@ -389,9 +388,9 @@ class MockStatusPacket(MockDynamixelPacketv2):
|
||||||
Returns:
|
Returns:
|
||||||
bytes: The raw 'Present_Position' status packet ready to be sent through serial.
|
bytes: The raw 'Present_Position' status packet ready to be sent through serial.
|
||||||
"""
|
"""
|
||||||
params = DynamixelMotorsBus._split_int_to_bytes(value, param_length)
|
params = _split_into_byte_chunks(value, param_length)
|
||||||
length = param_length + 4
|
length = param_length + 4
|
||||||
return cls.build(dxl_id, params=params, length=length)
|
return cls.build(dxl_id, params=params, length=length, error=error)
|
||||||
|
|
||||||
|
|
||||||
class MockPortHandler(dxl.PortHandler):
|
class MockPortHandler(dxl.PortHandler):
|
||||||
|
@ -425,8 +424,6 @@ class MockMotors(MockSerial):
|
||||||
instruction packets. It is meant to test MotorsBus classes.
|
instruction packets. It is meant to test MotorsBus classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ctrl_table = X_SERIES_CONTROL_TABLE
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -455,10 +452,10 @@ class MockMotors(MockSerial):
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_ping_stub(
|
def build_ping_stub(
|
||||||
self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0
|
self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0, error: int = 0
|
||||||
) -> str:
|
) -> str:
|
||||||
ping_request = MockInstructionPacket.ping(dxl_id)
|
ping_request = MockInstructionPacket.ping(dxl_id)
|
||||||
return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver)
|
return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver, error)
|
||||||
ping_response = self._build_send_fn(return_packet, num_invalid_try)
|
ping_response = self._build_send_fn(return_packet, num_invalid_try)
|
||||||
stub_name = f"Ping_{dxl_id}"
|
stub_name = f"Ping_{dxl_id}"
|
||||||
self.stub(
|
self.stub(
|
||||||
|
@ -468,14 +465,63 @@ class MockMotors(MockSerial):
|
||||||
)
|
)
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_sync_read_stub(
|
def build_read_stub(
|
||||||
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
|
self,
|
||||||
|
address: int,
|
||||||
|
length: int,
|
||||||
|
dxl_id: int,
|
||||||
|
value: int,
|
||||||
|
reply: bool = True,
|
||||||
|
error: int = 0,
|
||||||
|
num_invalid_try: int = 0,
|
||||||
|
) -> str:
|
||||||
|
read_request = MockInstructionPacket.read(dxl_id, address, length)
|
||||||
|
return_packet = MockStatusPacket.read(dxl_id, value, length, error) if reply else b""
|
||||||
|
read_response = self._build_send_fn(return_packet, num_invalid_try)
|
||||||
|
stub_name = f"Read_{address}_{length}_{dxl_id}_{value}_{error}"
|
||||||
|
self.stub(
|
||||||
|
name=stub_name,
|
||||||
|
receive_bytes=read_request,
|
||||||
|
send_fn=read_response,
|
||||||
|
)
|
||||||
|
return stub_name
|
||||||
|
|
||||||
|
def build_write_stub(
|
||||||
|
self,
|
||||||
|
address: int,
|
||||||
|
length: int,
|
||||||
|
dxl_id: int,
|
||||||
|
value: int,
|
||||||
|
reply: bool = True,
|
||||||
|
error: int = 0,
|
||||||
|
num_invalid_try: int = 0,
|
||||||
|
) -> str:
|
||||||
|
sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length)
|
||||||
|
return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error) if reply else b""
|
||||||
|
stub_name = f"Write_{address}_{length}_{dxl_id}"
|
||||||
|
self.stub(
|
||||||
|
name=stub_name,
|
||||||
|
receive_bytes=sync_read_request,
|
||||||
|
send_fn=self._build_send_fn(return_packet, num_invalid_try),
|
||||||
|
)
|
||||||
|
return stub_name
|
||||||
|
|
||||||
|
def build_sync_read_stub(
|
||||||
|
self,
|
||||||
|
address: int,
|
||||||
|
length: int,
|
||||||
|
ids_values: dict[int, int],
|
||||||
|
reply: bool = True,
|
||||||
|
num_invalid_try: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
address, length = self.ctrl_table[data_name]
|
|
||||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||||
return_packets = b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
|
return_packets = (
|
||||||
|
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
|
||||||
|
if reply
|
||||||
|
else b""
|
||||||
|
)
|
||||||
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
|
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
|
||||||
stub_name = f"Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
|
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||||
self.stub(
|
self.stub(
|
||||||
name=stub_name,
|
name=stub_name,
|
||||||
receive_bytes=sync_read_request,
|
receive_bytes=sync_read_request,
|
||||||
|
@ -484,11 +530,10 @@ class MockMotors(MockSerial):
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_sequential_sync_read_stub(
|
def build_sequential_sync_read_stub(
|
||||||
self, data_name: str, ids_values: dict[int, list[int]] | None = None
|
self, address: int, length: int, ids_values: dict[int, list[int]] | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
sequence_length = len(next(iter(ids_values.values())))
|
sequence_length = len(next(iter(ids_values.values())))
|
||||||
assert all(len(positions) == sequence_length for positions in ids_values.values())
|
assert all(len(positions) == sequence_length for positions in ids_values.values())
|
||||||
address, length = self.ctrl_table[data_name]
|
|
||||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||||
sequential_packets = []
|
sequential_packets = []
|
||||||
for count in range(sequence_length):
|
for count in range(sequence_length):
|
||||||
|
@ -498,7 +543,7 @@ class MockMotors(MockSerial):
|
||||||
sequential_packets.append(return_packets)
|
sequential_packets.append(return_packets)
|
||||||
|
|
||||||
sync_read_response = self._build_sequential_send_fn(sequential_packets)
|
sync_read_response = self._build_sequential_send_fn(sequential_packets)
|
||||||
stub_name = f"Seq_Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
|
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||||
self.stub(
|
self.stub(
|
||||||
name=stub_name,
|
name=stub_name,
|
||||||
receive_bytes=sync_read_request,
|
receive_bytes=sync_read_request,
|
||||||
|
@ -507,11 +552,10 @@ class MockMotors(MockSerial):
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_sync_write_stub(
|
def build_sync_write_stub(
|
||||||
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
|
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
|
||||||
) -> str:
|
) -> str:
|
||||||
address, length = self.ctrl_table[data_name]
|
|
||||||
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
|
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
|
||||||
stub_name = f"Sync_Write_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
|
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||||
self.stub(
|
self.stub(
|
||||||
name=stub_name,
|
name=stub_name,
|
||||||
receive_bytes=sync_read_request,
|
receive_bytes=sync_read_request,
|
||||||
|
@ -519,20 +563,6 @@ class MockMotors(MockSerial):
|
||||||
)
|
)
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_write_stub(
|
|
||||||
self, data_name: str, dxl_id: int, value: int, error: str = "Success", num_invalid_try: int = 0
|
|
||||||
) -> str:
|
|
||||||
address, length = self.ctrl_table[data_name]
|
|
||||||
sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length)
|
|
||||||
return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error)
|
|
||||||
stub_name = f"Write_{data_name}_{dxl_id}"
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=self._build_send_fn(return_packet, num_invalid_try),
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
||||||
def send_fn(_call_count: int) -> bytes:
|
def send_fn(_call_count: int) -> bytes:
|
||||||
|
|
|
@ -5,32 +5,10 @@ import scservo_sdk as scs
|
||||||
import serial
|
import serial
|
||||||
from mock_serial import MockSerial
|
from mock_serial import MockSerial
|
||||||
|
|
||||||
from lerobot.common.motors.feetech import STS_SMS_SERIES_CONTROL_TABLE, FeetechMotorsBus
|
from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout
|
||||||
from lerobot.common.motors.feetech.feetech import patch_setPacketTimeout
|
|
||||||
|
|
||||||
from .mock_serial_patch import WaitableStub
|
from .mock_serial_patch import WaitableStub
|
||||||
|
|
||||||
# https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf
|
|
||||||
INSTRUCTION_TYPES = {
|
|
||||||
"Read": scs.INST_PING, # Read data from the Device
|
|
||||||
"Ping": scs.INST_READ, # Checks whether the Packet has arrived at a device with the same ID as the specified packet ID
|
|
||||||
"Write": scs.INST_WRITE, # Write data to the Device
|
|
||||||
"Reg_Write": scs.INST_REG_WRITE, # Register the Instruction Packet in standby status; Packet can later be executed using the Action command
|
|
||||||
"Action": scs.INST_ACTION, # Executes a Packet that was registered beforehand using Reg Write
|
|
||||||
"Factory_Reset": 0x06, # Resets the Control Table to its initial factory default settings
|
|
||||||
"Sync_Write": scs.INST_SYNC_WRITE, # Write data to multiple devices with the same Address with the same length at once
|
|
||||||
"Sync_Read": scs.INST_SYNC_READ, # Read data from multiple devices with the same Address with the same length at once
|
|
||||||
} # fmt: skip
|
|
||||||
|
|
||||||
ERROR_TYPE = {
|
|
||||||
"Success": 0x00,
|
|
||||||
"Voltage": scs.ERRBIT_VOLTAGE,
|
|
||||||
"Angle": scs.ERRBIT_ANGLE,
|
|
||||||
"Overheat": scs.ERRBIT_OVERHEAT,
|
|
||||||
"Overele": scs.ERRBIT_OVERELE,
|
|
||||||
"Overload": scs.ERRBIT_OVERLOAD,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MockFeetechPacket(abc.ABC):
|
class MockFeetechPacket(abc.ABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -49,7 +27,7 @@ class MockFeetechPacket(abc.ABC):
|
||||||
for id_ in range(2, len(packet) - 1): # except header & checksum
|
for id_ in range(2, len(packet) - 1): # except header & checksum
|
||||||
checksum += packet[id_]
|
checksum += packet[id_]
|
||||||
|
|
||||||
packet[-1] = scs.SCS_LOBYTE(~checksum)
|
packet[-1] = ~checksum & 0xFF
|
||||||
|
|
||||||
return packet
|
return packet
|
||||||
|
|
||||||
|
@ -68,13 +46,12 @@ class MockInstructionPacket(MockFeetechPacket):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build(cls, scs_id: int, params: list[int], length: int, instruct_type: str) -> list[int]:
|
def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]:
|
||||||
instruct_value = INSTRUCTION_TYPES[instruct_type]
|
|
||||||
return [
|
return [
|
||||||
0xFF, 0xFF, # header
|
0xFF, 0xFF, # header
|
||||||
scs_id, # servo id
|
scs_id, # servo id
|
||||||
length, # length
|
length, # length
|
||||||
instruct_value, # instruction type
|
instruction, # instruction type
|
||||||
*params, # data bytes
|
*params, # data bytes
|
||||||
0x00, # placeholder for checksum
|
0x00, # placeholder for checksum
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
|
@ -89,7 +66,7 @@ class MockInstructionPacket(MockFeetechPacket):
|
||||||
|
|
||||||
No parameters required.
|
No parameters required.
|
||||||
"""
|
"""
|
||||||
return cls.build(scs_id=scs_id, params=[], length=2, instruct_type="Ping")
|
return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def read(
|
def read(
|
||||||
|
@ -113,7 +90,7 @@ class MockInstructionPacket(MockFeetechPacket):
|
||||||
"""
|
"""
|
||||||
params = [start_address, data_length]
|
params = [start_address, data_length]
|
||||||
length = 4
|
length = 4
|
||||||
return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Read")
|
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def write(
|
def write(
|
||||||
|
@ -139,10 +116,10 @@ class MockInstructionPacket(MockFeetechPacket):
|
||||||
+1 is for the length bytes,
|
+1 is for the length bytes,
|
||||||
+1 is for the checksum at the end.
|
+1 is for the checksum at the end.
|
||||||
"""
|
"""
|
||||||
data = FeetechMotorsBus._split_int_to_bytes(value, data_length)
|
data = _split_into_byte_chunks(value, data_length)
|
||||||
params = [start_address, *data]
|
params = [start_address, *data]
|
||||||
length = data_length + 3
|
length = data_length + 3
|
||||||
return cls.build(scs_id=scs_id, params=params, length=length, instruct_type="Write")
|
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sync_read(
|
def sync_read(
|
||||||
|
@ -167,7 +144,9 @@ class MockInstructionPacket(MockFeetechPacket):
|
||||||
"""
|
"""
|
||||||
params = [start_address, data_length, *scs_ids]
|
params = [start_address, data_length, *scs_ids]
|
||||||
length = len(scs_ids) + 4
|
length = len(scs_ids) + 4
|
||||||
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Read")
|
return cls.build(
|
||||||
|
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sync_write(
|
def sync_write(
|
||||||
|
@ -201,11 +180,13 @@ class MockInstructionPacket(MockFeetechPacket):
|
||||||
"""
|
"""
|
||||||
data = []
|
data = []
|
||||||
for id_, value in ids_values.items():
|
for id_, value in ids_values.items():
|
||||||
split_value = FeetechMotorsBus._split_int_to_bytes(value, data_length)
|
split_value = _split_into_byte_chunks(value, data_length)
|
||||||
data += [id_, *split_value]
|
data += [id_, *split_value]
|
||||||
params = [start_address, data_length, *data]
|
params = [start_address, data_length, *data]
|
||||||
length = len(ids_values) * (1 + data_length) + 4
|
length = len(ids_values) * (1 + data_length) + 4
|
||||||
return cls.build(scs_id=scs.BROADCAST_ID, params=params, length=length, instruct_type="Sync_Write")
|
return cls.build(
|
||||||
|
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MockStatusPacket(MockFeetechPacket):
|
class MockStatusPacket(MockFeetechPacket):
|
||||||
|
@ -222,19 +203,18 @@ class MockStatusPacket(MockFeetechPacket):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build(cls, scs_id: int, params: list[int], length: int, error: str = "Success") -> list[int]:
|
def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
|
||||||
err_byte = ERROR_TYPE[error]
|
|
||||||
return [
|
return [
|
||||||
0xFF, 0xFF, # header
|
0xFF, 0xFF, # header
|
||||||
scs_id, # servo id
|
scs_id, # servo id
|
||||||
length, # length
|
length, # length
|
||||||
err_byte, # status
|
error, # status
|
||||||
*params, # data bytes
|
*params, # data bytes
|
||||||
0x00, # placeholder for checksum
|
0x00, # placeholder for checksum
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ping(cls, scs_id: int, error: str = "Success") -> bytes:
|
def ping(cls, scs_id: int, error: int = 0) -> bytes:
|
||||||
"""Builds a 'Ping' status packet.
|
"""Builds a 'Ping' status packet.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -247,7 +227,7 @@ class MockStatusPacket(MockFeetechPacket):
|
||||||
return cls.build(scs_id, params=[], length=2, error=error)
|
return cls.build(scs_id, params=[], length=2, error=error)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def read(cls, scs_id: int, value: int, param_length: int) -> bytes:
|
def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes:
|
||||||
"""Builds a 'Read' status packet.
|
"""Builds a 'Read' status packet.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -258,9 +238,9 @@ class MockStatusPacket(MockFeetechPacket):
|
||||||
Returns:
|
Returns:
|
||||||
bytes: The raw 'Sync Read' status packet ready to be sent through serial.
|
bytes: The raw 'Sync Read' status packet ready to be sent through serial.
|
||||||
"""
|
"""
|
||||||
params = FeetechMotorsBus._split_int_to_bytes(value, param_length)
|
params = _split_into_byte_chunks(value, param_length)
|
||||||
length = param_length + 2
|
length = param_length + 2
|
||||||
return cls.build(scs_id, params=params, length=length)
|
return cls.build(scs_id, params=params, length=length, error=error)
|
||||||
|
|
||||||
|
|
||||||
class MockPortHandler(scs.PortHandler):
|
class MockPortHandler(scs.PortHandler):
|
||||||
|
@ -297,8 +277,6 @@ class MockMotors(MockSerial):
|
||||||
instruction packets. It is meant to test MotorsBus classes.
|
instruction packets. It is meant to test MotorsBus classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ctrl_table = STS_SMS_SERIES_CONTROL_TABLE
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -323,11 +301,11 @@ class MockMotors(MockSerial):
|
||||||
)
|
)
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0) -> str:
|
def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str:
|
||||||
ping_request = MockInstructionPacket.ping(scs_id)
|
ping_request = MockInstructionPacket.ping(scs_id)
|
||||||
return_packet = MockStatusPacket.ping(scs_id)
|
return_packet = MockStatusPacket.ping(scs_id, error)
|
||||||
ping_response = self._build_send_fn(return_packet, num_invalid_try)
|
ping_response = self._build_send_fn(return_packet, num_invalid_try)
|
||||||
stub_name = f"Ping_{scs_id}"
|
stub_name = f"Ping_{scs_id}_{error}"
|
||||||
self.stub(
|
self.stub(
|
||||||
name=stub_name,
|
name=stub_name,
|
||||||
receive_bytes=ping_request,
|
receive_bytes=ping_request,
|
||||||
|
@ -336,13 +314,19 @@ class MockMotors(MockSerial):
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_read_stub(
|
def build_read_stub(
|
||||||
self, data_name: str, scs_id: int, value: int | None = None, num_invalid_try: int = 0
|
self,
|
||||||
|
address: int,
|
||||||
|
length: int,
|
||||||
|
scs_id: int,
|
||||||
|
value: int,
|
||||||
|
reply: bool = True,
|
||||||
|
error: int = 0,
|
||||||
|
num_invalid_try: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
address, length = self.ctrl_table[data_name]
|
|
||||||
read_request = MockInstructionPacket.read(scs_id, address, length)
|
read_request = MockInstructionPacket.read(scs_id, address, length)
|
||||||
return_packet = MockStatusPacket.read(scs_id, value, length)
|
return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b""
|
||||||
read_response = self._build_send_fn(return_packet, num_invalid_try)
|
read_response = self._build_send_fn(return_packet, num_invalid_try)
|
||||||
stub_name = f"Read_{data_name}_{scs_id}"
|
stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}"
|
||||||
self.stub(
|
self.stub(
|
||||||
name=stub_name,
|
name=stub_name,
|
||||||
receive_bytes=read_request,
|
receive_bytes=read_request,
|
||||||
|
@ -350,15 +334,42 @@ class MockMotors(MockSerial):
|
||||||
)
|
)
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_sync_read_stub(
|
def build_write_stub(
|
||||||
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
|
self,
|
||||||
|
address: int,
|
||||||
|
length: int,
|
||||||
|
scs_id: int,
|
||||||
|
value: int,
|
||||||
|
reply: bool = True,
|
||||||
|
error: int = 0,
|
||||||
|
num_invalid_try: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
address, length = self.ctrl_table[data_name]
|
sync_read_request = MockInstructionPacket.write(scs_id, value, address, length)
|
||||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b""
|
||||||
return_packets = b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
|
stub_name = f"Write_{address}_{length}_{scs_id}"
|
||||||
|
self.stub(
|
||||||
|
name=stub_name,
|
||||||
|
receive_bytes=sync_read_request,
|
||||||
|
send_fn=self._build_send_fn(return_packet, num_invalid_try),
|
||||||
|
)
|
||||||
|
return stub_name
|
||||||
|
|
||||||
|
def build_sync_read_stub(
|
||||||
|
self,
|
||||||
|
address: int,
|
||||||
|
length: int,
|
||||||
|
ids_values: dict[int, int],
|
||||||
|
reply: bool = True,
|
||||||
|
num_invalid_try: int = 0,
|
||||||
|
) -> str:
|
||||||
|
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||||
|
return_packets = (
|
||||||
|
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
|
||||||
|
if reply
|
||||||
|
else b""
|
||||||
|
)
|
||||||
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
|
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
|
||||||
stub_name = f"Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
|
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||||
self.stub(
|
self.stub(
|
||||||
name=stub_name,
|
name=stub_name,
|
||||||
receive_bytes=sync_read_request,
|
receive_bytes=sync_read_request,
|
||||||
|
@ -367,11 +378,10 @@ class MockMotors(MockSerial):
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_sequential_sync_read_stub(
|
def build_sequential_sync_read_stub(
|
||||||
self, data_name: str, ids_values: dict[int, list[int]] | None = None
|
self, address: int, length: int, ids_values: dict[int, list[int]] | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
sequence_length = len(next(iter(ids_values.values())))
|
sequence_length = len(next(iter(ids_values.values())))
|
||||||
assert all(len(positions) == sequence_length for positions in ids_values.values())
|
assert all(len(positions) == sequence_length for positions in ids_values.values())
|
||||||
address, length = self.ctrl_table[data_name]
|
|
||||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
||||||
sequential_packets = []
|
sequential_packets = []
|
||||||
for count in range(sequence_length):
|
for count in range(sequence_length):
|
||||||
|
@ -381,7 +391,7 @@ class MockMotors(MockSerial):
|
||||||
sequential_packets.append(return_packets)
|
sequential_packets.append(return_packets)
|
||||||
|
|
||||||
sync_read_response = self._build_sequential_send_fn(sequential_packets)
|
sync_read_response = self._build_sequential_send_fn(sequential_packets)
|
||||||
stub_name = f"Seq_Sync_Read_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
|
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||||
self.stub(
|
self.stub(
|
||||||
name=stub_name,
|
name=stub_name,
|
||||||
receive_bytes=sync_read_request,
|
receive_bytes=sync_read_request,
|
||||||
|
@ -390,11 +400,10 @@ class MockMotors(MockSerial):
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_sync_write_stub(
|
def build_sync_write_stub(
|
||||||
self, data_name: str, ids_values: dict[int, int] | None = None, num_invalid_try: int = 0
|
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
|
||||||
) -> str:
|
) -> str:
|
||||||
address, length = self.ctrl_table[data_name]
|
|
||||||
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
|
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
|
||||||
stub_name = f"Sync_Write_{data_name}_" + "_".join([str(id_) for id_ in ids_values])
|
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
||||||
self.stub(
|
self.stub(
|
||||||
name=stub_name,
|
name=stub_name,
|
||||||
receive_bytes=sync_read_request,
|
receive_bytes=sync_read_request,
|
||||||
|
@ -402,20 +411,6 @@ class MockMotors(MockSerial):
|
||||||
)
|
)
|
||||||
return stub_name
|
return stub_name
|
||||||
|
|
||||||
def build_write_stub(
|
|
||||||
self, data_name: str, scs_id: int, value: int, error: str = "Success", num_invalid_try: int = 0
|
|
||||||
) -> str:
|
|
||||||
address, length = self.ctrl_table[data_name]
|
|
||||||
sync_read_request = MockInstructionPacket.write(scs_id, value, address, length)
|
|
||||||
return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error)
|
|
||||||
stub_name = f"Write_{data_name}_{scs_id}"
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=self._build_send_fn(return_packet, num_invalid_try),
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
||||||
def send_fn(_call_count: int) -> bytes:
|
def send_fn(_call_count: int) -> bytes:
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
@ -7,6 +8,7 @@ import pytest
|
||||||
|
|
||||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
from lerobot.common.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus
|
from lerobot.common.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus
|
||||||
|
from lerobot.common.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE
|
||||||
from lerobot.common.utils.encoding_utils import encode_twos_complement
|
from lerobot.common.utils.encoding_utils import encode_twos_complement
|
||||||
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
|
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
|
||||||
|
|
||||||
|
@ -62,48 +64,21 @@ def test_autouse_patch():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"value, n_bytes, expected",
|
"value, length, expected",
|
||||||
[
|
[
|
||||||
(0x12, 1, [0x12]),
|
(0x12, 1, [0x12]),
|
||||||
(0x1234, 2, [0x34, 0x12]),
|
(0x1234, 2, [0x34, 0x12]),
|
||||||
(0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
|
(0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
|
||||||
(0, 1, [0x00]),
|
|
||||||
(0, 2, [0x00, 0x00]),
|
|
||||||
(0, 4, [0x00, 0x00, 0x00, 0x00]),
|
|
||||||
(255, 1, [0xFF]),
|
|
||||||
(65535, 2, [0xFF, 0xFF]),
|
|
||||||
(4294967295, 4, [0xFF, 0xFF, 0xFF, 0xFF]),
|
|
||||||
],
|
],
|
||||||
ids=[
|
ids=[
|
||||||
"1 byte",
|
"1 byte",
|
||||||
"2 bytes",
|
"2 bytes",
|
||||||
"4 bytes",
|
"4 bytes",
|
||||||
"0 with 1 byte",
|
|
||||||
"0 with 2 bytes",
|
|
||||||
"0 with 4 bytes",
|
|
||||||
"max single byte",
|
|
||||||
"max two bytes",
|
|
||||||
"max four bytes",
|
|
||||||
],
|
],
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
def test_split_int_to_bytes(value, n_bytes, expected):
|
def test__split_into_byte_chunks(value, length, expected):
|
||||||
assert DynamixelMotorsBus._split_int_to_bytes(value, n_bytes) == expected
|
bus = DynamixelMotorsBus("", {})
|
||||||
|
assert bus._split_into_byte_chunks(value, length) == expected
|
||||||
|
|
||||||
def test_split_int_to_bytes_invalid_n_bytes():
|
|
||||||
with pytest.raises(NotImplementedError):
|
|
||||||
DynamixelMotorsBus._split_int_to_bytes(100, 3)
|
|
||||||
|
|
||||||
|
|
||||||
def test_split_int_to_bytes_negative_numbers():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
neg = DynamixelMotorsBus._split_int_to_bytes(-1, 1)
|
|
||||||
print(neg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_split_int_to_bytes_large_number():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
DynamixelMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
|
||||||
|
|
||||||
|
|
||||||
def test_abc_implementation(dummy_motors):
|
def test_abc_implementation(dummy_motors):
|
||||||
|
@ -114,204 +89,195 @@ def test_abc_implementation(dummy_motors):
|
||||||
@pytest.mark.parametrize("id_", [1, 2, 3])
|
@pytest.mark.parametrize("id_", [1, 2, 3])
|
||||||
def test_ping(id_, mock_motors, dummy_motors):
|
def test_ping(id_, mock_motors, dummy_motors):
|
||||||
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
|
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
|
||||||
stub_name = mock_motors.build_ping_stub(id_, expected_model_nb)
|
stub = mock_motors.build_ping_stub(id_, expected_model_nb)
|
||||||
motors_bus = DynamixelMotorsBus(
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
port=mock_motors.port,
|
bus.connect(handshake=False)
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
ping_model_nb = motors_bus.ping(id_)
|
ping_model_nb = bus.ping(id_)
|
||||||
|
|
||||||
assert ping_model_nb == expected_model_nb
|
assert ping_model_nb == expected_model_nb
|
||||||
assert mock_motors.stubs[stub_name].called
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
def test_broadcast_ping(mock_motors, dummy_motors):
|
def test_broadcast_ping(mock_motors, dummy_motors):
|
||||||
models = {m.id: m.model for m in dummy_motors.values()}
|
models = {m.id: m.model for m in dummy_motors.values()}
|
||||||
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
|
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
|
||||||
stub_name = mock_motors.build_broadcast_ping_stub(expected_model_nbs)
|
stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs)
|
||||||
motors_bus = DynamixelMotorsBus(
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
port=mock_motors.port,
|
bus.connect(handshake=False)
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
ping_model_nbs = motors_bus.broadcast_ping()
|
ping_model_nbs = bus.broadcast_ping()
|
||||||
|
|
||||||
assert ping_model_nbs == expected_model_nbs
|
assert ping_model_nbs == expected_model_nbs
|
||||||
assert mock_motors.stubs[stub_name].called
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
def test_sync_read_none(mock_motors, dummy_motors):
|
|
||||||
expected_positions = {
|
|
||||||
"dummy_1": 1337,
|
|
||||||
"dummy_2": 42,
|
|
||||||
"dummy_3": 4016,
|
|
||||||
}
|
|
||||||
ids_values = dict(zip([1, 2, 3], expected_positions.values(), strict=True))
|
|
||||||
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
|
|
||||||
motors_bus = DynamixelMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
read_positions = motors_bus.sync_read("Present_Position", normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].called
|
|
||||||
assert read_positions == expected_positions
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"id_, position",
|
"addr, length, id_, value",
|
||||||
[
|
[
|
||||||
(1, 1337),
|
(0, 1, 1, 2),
|
||||||
(2, 42),
|
(10, 2, 2, 999),
|
||||||
(3, 4016),
|
(42, 4, 3, 1337),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sync_read_single_value(id_, position, mock_motors, dummy_motors):
|
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
|
||||||
expected_position = {f"dummy_{id_}": position}
|
stub = mock_motors.build_read_stub(addr, length, id_, value)
|
||||||
stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position})
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
motors_bus = DynamixelMotorsBus(
|
bus.connect(handshake=False)
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
read_position = motors_bus.sync_read("Present_Position", f"dummy_{id_}", normalize=False)
|
read_value, _, _ = bus._read(addr, length, id_)
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].called
|
assert mock_motors.stubs[stub].called
|
||||||
assert read_position == expected_position
|
assert read_value == value
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
"ids, positions",
|
def test__read_error(raise_on_error, mock_motors, dummy_motors):
|
||||||
[
|
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
|
||||||
([1], [1337]),
|
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
|
||||||
([1, 2], [1337, 42]),
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
([1, 2, 3], [1337, 42, 4016]),
|
bus.connect(handshake=False)
|
||||||
],
|
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
|
||||||
) # fmt: skip
|
|
||||||
def test_sync_read(ids, positions, mock_motors, dummy_motors):
|
|
||||||
assert len(ids) == len(positions)
|
|
||||||
names = [f"dummy_{dxl_id}" for dxl_id in ids]
|
|
||||||
expected_positions = dict(zip(names, positions, strict=True))
|
|
||||||
ids_values = dict(zip(ids, positions, strict=True))
|
|
||||||
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
|
|
||||||
motors_bus = DynamixelMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
read_positions = motors_bus.sync_read("Present_Position", names, normalize=False)
|
if raise_on_error:
|
||||||
|
with pytest.raises(
|
||||||
assert mock_motors.stubs[stub_name].called
|
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
|
||||||
assert read_positions == expected_positions
|
):
|
||||||
|
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"num_retry, num_invalid_try, pos",
|
|
||||||
[
|
|
||||||
(0, 2, 1337),
|
|
||||||
(2, 3, 42),
|
|
||||||
(3, 2, 4016),
|
|
||||||
(2, 1, 999),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors):
|
|
||||||
expected_position = {"dummy_1": pos}
|
|
||||||
stub_name = mock_motors.build_sync_read_stub(
|
|
||||||
"Present_Position", {1: pos}, num_invalid_try=num_invalid_try
|
|
||||||
)
|
|
||||||
motors_bus = DynamixelMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
if num_retry >= num_invalid_try:
|
|
||||||
pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
|
|
||||||
assert pos_dict == expected_position
|
|
||||||
else:
|
else:
|
||||||
with pytest.raises(ConnectionError):
|
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||||
_ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
|
assert read_error == error
|
||||||
|
|
||||||
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
|
assert mock_motors.stubs[stub].called
|
||||||
assert mock_motors.stubs[stub_name].calls == expected_calls
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
|
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
|
||||||
|
addr, length, id_, value = (10, 4, 1, 1337)
|
||||||
|
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
|
||||||
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
if raise_on_error:
|
||||||
|
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||||
|
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||||
|
else:
|
||||||
|
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||||
|
assert read_comm == dxl.COMM_RX_TIMEOUT
|
||||||
|
|
||||||
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"data_name, value",
|
"addr, length, id_, value",
|
||||||
[
|
[
|
||||||
("Torque_Enable", 0),
|
(0, 1, 1, 2),
|
||||||
("Torque_Enable", 1),
|
(10, 2, 2, 999),
|
||||||
("Goal_Position", 1337),
|
(42, 4, 3, 1337),
|
||||||
("Goal_Position", 42),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors):
|
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
|
||||||
ids_values = {m.id: value for m in dummy_motors.values()}
|
stub = mock_motors.build_write_stub(addr, length, id_, value)
|
||||||
stub_name = mock_motors.build_sync_write_stub(data_name, ids_values)
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
motors_bus = DynamixelMotorsBus(
|
bus.connect(handshake=False)
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
motors_bus.sync_write(data_name, value, normalize=False)
|
comm, error = bus._write(addr, length, id_, value)
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].wait_called()
|
assert mock_motors.stubs[stub].called
|
||||||
|
assert comm == dxl.COMM_SUCCESS
|
||||||
|
assert error == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
|
def test__write_error(raise_on_error, mock_motors, dummy_motors):
|
||||||
|
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
|
||||||
|
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
|
||||||
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
if raise_on_error:
|
||||||
|
with pytest.raises(
|
||||||
|
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
|
||||||
|
):
|
||||||
|
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||||
|
else:
|
||||||
|
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||||
|
assert write_error == error
|
||||||
|
|
||||||
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
|
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
|
||||||
|
addr, length, id_, value = (10, 4, 1, 1337)
|
||||||
|
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
|
||||||
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
if raise_on_error:
|
||||||
|
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||||
|
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||||
|
else:
|
||||||
|
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||||
|
assert write_comm == dxl.COMM_RX_TIMEOUT
|
||||||
|
|
||||||
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ids, positions",
|
"addr, length, ids_values",
|
||||||
[
|
[
|
||||||
([1], [1337]),
|
(0, 1, {1: 4}),
|
||||||
([1, 2], [1337, 42]),
|
(10, 2, {1: 1337, 2: 42}),
|
||||||
([1, 2, 3], [1337, 42, 4016]),
|
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
||||||
],
|
],
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
ids=["1 motor", "2 motors", "3 motors"],
|
||||||
) # fmt: skip
|
)
|
||||||
def test_sync_write(ids, positions, mock_motors, dummy_motors):
|
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
|
||||||
assert len(ids) == len(positions)
|
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
|
||||||
ids_values = dict(zip(ids, positions, strict=True))
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values)
|
bus.connect(handshake=False)
|
||||||
motors_bus = DynamixelMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
write_values = {f"dummy_{id_}": pos for id_, pos in ids_values.items()}
|
read_values, _ = bus._sync_read(addr, length, list(ids_values))
|
||||||
motors_bus.sync_write("Goal_Position", write_values, normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].wait_called()
|
assert mock_motors.stubs[stub].called
|
||||||
|
assert read_values == ids_values
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
|
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
|
||||||
|
addr, length, ids_values = (10, 4, {1: 1337})
|
||||||
|
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
|
||||||
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
if raise_on_error:
|
||||||
|
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||||
|
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
||||||
|
else:
|
||||||
|
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
||||||
|
assert read_comm == dxl.COMM_RX_TIMEOUT
|
||||||
|
|
||||||
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"data_name, dxl_id, value",
|
"addr, length, ids_values",
|
||||||
[
|
[
|
||||||
("Torque_Enable", 1, 0),
|
(0, 1, {1: 4}),
|
||||||
("Torque_Enable", 1, 1),
|
(10, 2, {1: 1337, 2: 42}),
|
||||||
("Goal_Position", 2, 1337),
|
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
||||||
("Goal_Position", 3, 42),
|
|
||||||
],
|
],
|
||||||
|
ids=["1 motor", "2 motors", "3 motors"],
|
||||||
)
|
)
|
||||||
def test_write(data_name, dxl_id, value, mock_motors, dummy_motors):
|
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
|
||||||
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
|
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
|
||||||
motors_bus = DynamixelMotorsBus(
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
port=mock_motors.port,
|
bus.connect(handshake=False)
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
motors_bus.write(data_name, f"dummy_{dxl_id}", value, normalize=False)
|
comm = bus._sync_write(addr, length, ids_values)
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].called
|
assert mock_motors.stubs[stub].wait_called()
|
||||||
|
assert comm == dxl.COMM_SUCCESS
|
||||||
|
|
||||||
|
|
||||||
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
|
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
|
||||||
|
@ -319,18 +285,18 @@ def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
|
||||||
encoded_homings = {m.id: encode_twos_complement(m.homing_offset, 4) for m in dummy_calibration.values()}
|
encoded_homings = {m.id: encode_twos_complement(m.homing_offset, 4) for m in dummy_calibration.values()}
|
||||||
mins = {m.id: m.range_min for m in dummy_calibration.values()}
|
mins = {m.id: m.range_min for m in dummy_calibration.values()}
|
||||||
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
|
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
|
||||||
drive_modes_stub = mock_motors.build_sync_read_stub("Drive_Mode", drive_modes)
|
drive_modes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Drive_Mode"], drive_modes)
|
||||||
offsets_stub = mock_motors.build_sync_read_stub("Homing_Offset", encoded_homings)
|
offsets_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings)
|
||||||
mins_stub = mock_motors.build_sync_read_stub("Min_Position_Limit", mins)
|
mins_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins)
|
||||||
maxes_stub = mock_motors.build_sync_read_stub("Max_Position_Limit", maxes)
|
maxes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes)
|
||||||
motors_bus = DynamixelMotorsBus(
|
bus = DynamixelMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors=dummy_motors,
|
motors=dummy_motors,
|
||||||
calibration=dummy_calibration,
|
calibration=dummy_calibration,
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
is_calibrated = motors_bus.is_calibrated
|
is_calibrated = bus.is_calibrated
|
||||||
|
|
||||||
assert is_calibrated
|
assert is_calibrated
|
||||||
assert mock_motors.stubs[drive_modes_stub].called
|
assert mock_motors.stubs[drive_modes_stub].called
|
||||||
|
@ -344,17 +310,20 @@ def test_reset_calibration(mock_motors, dummy_motors):
|
||||||
write_mins_stubs = []
|
write_mins_stubs = []
|
||||||
write_maxes_stubs = []
|
write_maxes_stubs = []
|
||||||
for motor in dummy_motors.values():
|
for motor in dummy_motors.values():
|
||||||
write_homing_stubs.append(mock_motors.build_write_stub("Homing_Offset", motor.id, 0))
|
write_homing_stubs.append(
|
||||||
write_mins_stubs.append(mock_motors.build_write_stub("Min_Position_Limit", motor.id, 0))
|
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
|
||||||
write_maxes_stubs.append(mock_motors.build_write_stub("Max_Position_Limit", motor.id, 4095))
|
)
|
||||||
|
write_mins_stubs.append(
|
||||||
motors_bus = DynamixelMotorsBus(
|
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
|
||||||
port=mock_motors.port,
|
)
|
||||||
motors=dummy_motors,
|
write_maxes_stubs.append(
|
||||||
|
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
motors_bus.reset_calibration()
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
bus.reset_calibration()
|
||||||
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
|
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
|
||||||
|
@ -376,23 +345,22 @@ def test_set_half_turn_homings(mock_motors, dummy_motors):
|
||||||
2: 2005, # 2047 - 42
|
2: 2005, # 2047 - 42
|
||||||
3: -1625, # 2047 - 3672
|
3: -1625, # 2047 - 3672
|
||||||
}
|
}
|
||||||
read_pos_stub = mock_motors.build_sync_read_stub("Present_Position", current_positions)
|
read_pos_stub = mock_motors.build_sync_read_stub(
|
||||||
|
*X_SERIES_CONTROL_TABLE["Present_Position"], current_positions
|
||||||
|
)
|
||||||
write_homing_stubs = []
|
write_homing_stubs = []
|
||||||
for id_, homing in expected_homings.items():
|
for id_, homing in expected_homings.items():
|
||||||
encoded_homing = encode_twos_complement(homing, 4)
|
encoded_homing = encode_twos_complement(homing, 4)
|
||||||
stub = mock_motors.build_write_stub("Homing_Offset", id_, encoded_homing)
|
stub = mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing)
|
||||||
write_homing_stubs.append(stub)
|
write_homing_stubs.append(stub)
|
||||||
|
|
||||||
motors_bus = DynamixelMotorsBus(
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
port=mock_motors.port,
|
bus.connect(handshake=False)
|
||||||
motors=dummy_motors,
|
bus.reset_calibration = MagicMock()
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
motors_bus.reset_calibration = MagicMock()
|
|
||||||
|
|
||||||
motors_bus.set_half_turn_homings()
|
bus.set_half_turn_homings()
|
||||||
|
|
||||||
motors_bus.reset_calibration.assert_called_once()
|
bus.reset_calibration.assert_called_once()
|
||||||
assert mock_motors.stubs[read_pos_stub].called
|
assert mock_motors.stubs[read_pos_stub].called
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||||
|
|
||||||
|
@ -413,15 +381,14 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors):
|
||||||
"dummy_2": 3600,
|
"dummy_2": 3600,
|
||||||
"dummy_3": 4002,
|
"dummy_3": 4002,
|
||||||
}
|
}
|
||||||
read_pos_stub = mock_motors.build_sequential_sync_read_stub("Present_Position", positions)
|
read_pos_stub = mock_motors.build_sequential_sync_read_stub(
|
||||||
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
*X_SERIES_CONTROL_TABLE["Present_Position"], positions
|
||||||
motors_bus = DynamixelMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
||||||
|
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
mins, maxes = motors_bus.record_ranges_of_motion(display_values=False)
|
mins, maxes = bus.record_ranges_of_motion(display_values=False)
|
||||||
|
|
||||||
assert mock_motors.stubs[read_pos_stub].calls == 3
|
assert mock_motors.stubs[read_pos_stub].calls == 3
|
||||||
assert mins == expected_mins
|
assert mins == expected_mins
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
@ -6,7 +7,8 @@ import pytest
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
from lerobot.common.motors.feetech import MODEL_NUMBER_TABLE, FeetechMotorsBus
|
from lerobot.common.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus
|
||||||
|
from lerobot.common.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE
|
||||||
from lerobot.common.utils.encoding_utils import encode_sign_magnitude
|
from lerobot.common.utils.encoding_utils import encode_sign_magnitude
|
||||||
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
|
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
|
||||||
|
|
||||||
|
@ -61,48 +63,27 @@ def test_autouse_patch():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"value, n_bytes, expected",
|
"protocol, value, length, expected",
|
||||||
[
|
[
|
||||||
(0x12, 1, [0x12]),
|
(0, 0x12, 1, [0x12]),
|
||||||
(0x1234, 2, [0x34, 0x12]),
|
(1, 0x12, 1, [0x12]),
|
||||||
(0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
|
(0, 0x1234, 2, [0x34, 0x12]),
|
||||||
(0, 1, [0x00]),
|
(1, 0x1234, 2, [0x12, 0x34]),
|
||||||
(0, 2, [0x00, 0x00]),
|
(0, 0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
|
||||||
(0, 4, [0x00, 0x00, 0x00, 0x00]),
|
(1, 0x12345678, 4, [0x56, 0x78, 0x12, 0x34]),
|
||||||
(255, 1, [0xFF]),
|
|
||||||
(65535, 2, [0xFF, 0xFF]),
|
|
||||||
(4294967295, 4, [0xFF, 0xFF, 0xFF, 0xFF]),
|
|
||||||
],
|
],
|
||||||
ids=[
|
ids=[
|
||||||
"1 byte",
|
"P0: 1 byte",
|
||||||
"2 bytes",
|
"P1: 1 byte",
|
||||||
"4 bytes",
|
"P0: 2 bytes",
|
||||||
"0 with 1 byte",
|
"P1: 2 bytes",
|
||||||
"0 with 2 bytes",
|
"P0: 4 bytes",
|
||||||
"0 with 4 bytes",
|
"P1: 4 bytes",
|
||||||
"max single byte",
|
|
||||||
"max two bytes",
|
|
||||||
"max four bytes",
|
|
||||||
],
|
],
|
||||||
) # fmt: skip
|
) # fmt: skip
|
||||||
def test_split_int_to_bytes(value, n_bytes, expected):
|
def test__split_into_byte_chunks(protocol, value, length, expected):
|
||||||
assert FeetechMotorsBus._split_int_to_bytes(value, n_bytes) == expected
|
bus = FeetechMotorsBus("", {}, protocol_version=protocol)
|
||||||
|
assert bus._split_into_byte_chunks(value, length) == expected
|
||||||
|
|
||||||
def test_split_int_to_bytes_invalid_n_bytes():
|
|
||||||
with pytest.raises(NotImplementedError):
|
|
||||||
FeetechMotorsBus._split_int_to_bytes(100, 3)
|
|
||||||
|
|
||||||
|
|
||||||
def test_split_int_to_bytes_negative_numbers():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
neg = FeetechMotorsBus._split_int_to_bytes(-1, 1)
|
|
||||||
print(neg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_split_int_to_bytes_large_number():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
FeetechMotorsBus._split_int_to_bytes(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
|
||||||
|
|
||||||
|
|
||||||
def test_abc_implementation(dummy_motors):
|
def test_abc_implementation(dummy_motors):
|
||||||
|
@ -110,35 +91,19 @@ def test_abc_implementation(dummy_motors):
|
||||||
FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
|
FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO")
|
|
||||||
def test_scan_port(mock_motors):
|
|
||||||
expected = {
|
|
||||||
9_600: {1: 777},
|
|
||||||
57_600: {2: 777},
|
|
||||||
500_000: {237: 777},
|
|
||||||
}
|
|
||||||
expected_model_nbs = {id_: model for d in expected.values() for id_, model in d.items()}
|
|
||||||
ping_stub = mock_motors.build_broadcast_ping_stub(list(expected_model_nbs))
|
|
||||||
mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs)
|
|
||||||
found = FeetechMotorsBus.scan_port(mock_motors.port)
|
|
||||||
|
|
||||||
assert found == expected
|
|
||||||
assert mock_motors.stubs[ping_stub].called
|
|
||||||
assert mock_motors.stubs[mobel_nb_stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("id_", [1, 2, 3])
|
@pytest.mark.parametrize("id_", [1, 2, 3])
|
||||||
def test_ping(id_, mock_motors, dummy_motors):
|
def test_ping(id_, mock_motors, dummy_motors):
|
||||||
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
|
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
|
||||||
|
addr, length = MODEL_NUMBER
|
||||||
ping_stub = mock_motors.build_ping_stub(id_)
|
ping_stub = mock_motors.build_ping_stub(id_)
|
||||||
mobel_nb_stub = mock_motors.build_read_stub("Model_Number", id_, expected_model_nb)
|
mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb)
|
||||||
motors_bus = FeetechMotorsBus(
|
bus = FeetechMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors=dummy_motors,
|
motors=dummy_motors,
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
ping_model_nb = motors_bus.ping(id_)
|
ping_model_nb = bus.ping(id_)
|
||||||
|
|
||||||
assert ping_model_nb == expected_model_nb
|
assert ping_model_nb == expected_model_nb
|
||||||
assert mock_motors.stubs[ping_stub].called
|
assert mock_motors.stubs[ping_stub].called
|
||||||
|
@ -147,208 +112,221 @@ def test_ping(id_, mock_motors, dummy_motors):
|
||||||
|
|
||||||
def test_broadcast_ping(mock_motors, dummy_motors):
|
def test_broadcast_ping(mock_motors, dummy_motors):
|
||||||
models = {m.id: m.model for m in dummy_motors.values()}
|
models = {m.id: m.model for m in dummy_motors.values()}
|
||||||
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
|
addr, length = MODEL_NUMBER
|
||||||
ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
|
ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
|
||||||
mobel_nb_stub = mock_motors.build_sync_read_stub("Model_Number", expected_model_nbs)
|
mobel_nb_stubs = []
|
||||||
motors_bus = FeetechMotorsBus(
|
expected_model_nbs = {}
|
||||||
|
for id_, model in models.items():
|
||||||
|
model_nb = MODEL_NUMBER_TABLE[model]
|
||||||
|
stub = mock_motors.build_read_stub(addr, length, id_, model_nb)
|
||||||
|
expected_model_nbs[id_] = model_nb
|
||||||
|
mobel_nb_stubs.append(stub)
|
||||||
|
bus = FeetechMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors=dummy_motors,
|
motors=dummy_motors,
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
ping_model_nbs = motors_bus.broadcast_ping()
|
ping_model_nbs = bus.broadcast_ping()
|
||||||
|
|
||||||
assert ping_model_nbs == expected_model_nbs
|
assert ping_model_nbs == expected_model_nbs
|
||||||
assert mock_motors.stubs[ping_stub].called
|
assert mock_motors.stubs[ping_stub].called
|
||||||
assert mock_motors.stubs[mobel_nb_stub].called
|
assert all(mock_motors.stubs[stub].called for stub in mobel_nb_stubs)
|
||||||
|
|
||||||
|
|
||||||
def test_sync_read_none(mock_motors, dummy_motors):
|
|
||||||
expected_positions = {
|
|
||||||
"dummy_1": 1337,
|
|
||||||
"dummy_2": 42,
|
|
||||||
"dummy_3": 4016,
|
|
||||||
}
|
|
||||||
ids_values = dict(zip([1, 2, 3], expected_positions.values(), strict=True))
|
|
||||||
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
read_positions = motors_bus.sync_read("Present_Position", normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].called
|
|
||||||
assert read_positions == expected_positions
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"id_, position",
|
"addr, length, id_, value",
|
||||||
[
|
[
|
||||||
(1, 1337),
|
(0, 1, 1, 2),
|
||||||
(2, 42),
|
(10, 2, 2, 999),
|
||||||
(3, 4016),
|
(42, 4, 3, 1337),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sync_read_single_value(id_, position, mock_motors, dummy_motors):
|
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
|
||||||
expected_position = {f"dummy_{id_}": position}
|
stub = mock_motors.build_read_stub(addr, length, id_, value)
|
||||||
stub_name = mock_motors.build_sync_read_stub("Present_Position", {id_: position})
|
bus = FeetechMotorsBus(
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors=dummy_motors,
|
motors=dummy_motors,
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
read_position = motors_bus.sync_read("Present_Position", f"dummy_{id_}", normalize=False)
|
read_value, _, _ = bus._read(addr, length, id_)
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].called
|
assert mock_motors.stubs[stub].called
|
||||||
assert read_position == expected_position
|
assert read_value == value
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
"ids, positions",
|
def test__read_error(raise_on_error, mock_motors, dummy_motors):
|
||||||
[
|
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
|
||||||
([1], [1337]),
|
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
|
||||||
([1, 2], [1337, 42]),
|
bus = FeetechMotorsBus(
|
||||||
([1, 2, 3], [1337, 42, 4016]),
|
|
||||||
],
|
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
|
||||||
) # fmt: skip
|
|
||||||
def test_sync_read(ids, positions, mock_motors, dummy_motors):
|
|
||||||
assert len(ids) == len(positions)
|
|
||||||
names = [f"dummy_{dxl_id}" for dxl_id in ids]
|
|
||||||
expected_positions = dict(zip(names, positions, strict=True))
|
|
||||||
ids_values = dict(zip(ids, positions, strict=True))
|
|
||||||
stub_name = mock_motors.build_sync_read_stub("Present_Position", ids_values)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors=dummy_motors,
|
motors=dummy_motors,
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
read_positions = motors_bus.sync_read("Present_Position", names, normalize=False)
|
if raise_on_error:
|
||||||
|
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
|
||||||
assert mock_motors.stubs[stub_name].called
|
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||||
assert read_positions == expected_positions
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"num_retry, num_invalid_try, pos",
|
|
||||||
[
|
|
||||||
(0, 2, 1337),
|
|
||||||
(2, 3, 42),
|
|
||||||
(3, 2, 4016),
|
|
||||||
(2, 1, 999),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sync_read_num_retry(num_retry, num_invalid_try, pos, mock_motors, dummy_motors):
|
|
||||||
expected_position = {"dummy_1": pos}
|
|
||||||
stub_name = mock_motors.build_sync_read_stub(
|
|
||||||
"Present_Position", {1: pos}, num_invalid_try=num_invalid_try
|
|
||||||
)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
if num_retry >= num_invalid_try:
|
|
||||||
pos_dict = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
|
|
||||||
assert pos_dict == expected_position
|
|
||||||
else:
|
else:
|
||||||
with pytest.raises(ConnectionError):
|
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||||
_ = motors_bus.sync_read("Present_Position", "dummy_1", normalize=False, num_retry=num_retry)
|
assert read_error == error
|
||||||
|
|
||||||
expected_calls = min(1 + num_retry, 1 + num_invalid_try)
|
assert mock_motors.stubs[stub].called
|
||||||
assert mock_motors.stubs[stub_name].calls == expected_calls
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
"data_name, value",
|
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
|
||||||
[
|
addr, length, id_, value = (10, 4, 1, 1337)
|
||||||
("Torque_Enable", 0),
|
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
|
||||||
("Torque_Enable", 1),
|
bus = FeetechMotorsBus(
|
||||||
("Goal_Position", 1337),
|
|
||||||
("Goal_Position", 42),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sync_write_single_value(data_name, value, mock_motors, dummy_motors):
|
|
||||||
ids_values = {m.id: value for m in dummy_motors.values()}
|
|
||||||
stub_name = mock_motors.build_sync_write_stub(data_name, ids_values)
|
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors=dummy_motors,
|
motors=dummy_motors,
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
motors_bus.sync_write(data_name, value, normalize=False)
|
if raise_on_error:
|
||||||
|
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||||
|
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||||
|
else:
|
||||||
|
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
||||||
|
assert read_comm == scs.COMM_RX_TIMEOUT
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].wait_called()
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ids, positions",
|
"addr, length, id_, value",
|
||||||
[
|
[
|
||||||
([1], [1337]),
|
(0, 1, 1, 2),
|
||||||
([1, 2], [1337, 42]),
|
(10, 2, 2, 999),
|
||||||
([1, 2, 3], [1337, 42, 4016]),
|
(42, 4, 3, 1337),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
|
||||||
|
stub = mock_motors.build_write_stub(addr, length, id_, value)
|
||||||
|
bus = FeetechMotorsBus(
|
||||||
|
port=mock_motors.port,
|
||||||
|
motors=dummy_motors,
|
||||||
|
)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
comm, error = bus._write(addr, length, id_, value)
|
||||||
|
|
||||||
|
assert mock_motors.stubs[stub].called
|
||||||
|
assert comm == scs.COMM_SUCCESS
|
||||||
|
assert error == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
|
def test__write_error(raise_on_error, mock_motors, dummy_motors):
|
||||||
|
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
|
||||||
|
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
|
||||||
|
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
if raise_on_error:
|
||||||
|
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
|
||||||
|
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||||
|
else:
|
||||||
|
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||||
|
assert write_error == error
|
||||||
|
|
||||||
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
|
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
|
||||||
|
addr, length, id_, value = (10, 4, 1, 1337)
|
||||||
|
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
|
||||||
|
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
if raise_on_error:
|
||||||
|
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||||
|
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||||
|
else:
|
||||||
|
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
||||||
|
assert write_comm == scs.COMM_RX_TIMEOUT
|
||||||
|
|
||||||
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"addr, length, ids_values",
|
||||||
|
[
|
||||||
|
(0, 1, {1: 4}),
|
||||||
|
(10, 2, {1: 1337, 2: 42}),
|
||||||
|
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
||||||
],
|
],
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
ids=["1 motor", "2 motors", "3 motors"],
|
||||||
) # fmt: skip
|
)
|
||||||
def test_sync_write(ids, positions, mock_motors, dummy_motors):
|
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
|
||||||
assert len(ids) == len(positions)
|
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
|
||||||
ids_values = dict(zip(ids, positions, strict=True))
|
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
stub_name = mock_motors.build_sync_write_stub("Goal_Position", ids_values)
|
bus.connect(handshake=False)
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
write_values = {f"dummy_{id_}": pos for id_, pos in ids_values.items()}
|
read_values, _ = bus._sync_read(addr, length, list(ids_values))
|
||||||
motors_bus.sync_write("Goal_Position", write_values, normalize=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].wait_called()
|
assert mock_motors.stubs[stub].called
|
||||||
|
assert read_values == ids_values
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("raise_on_error", (True, False))
|
||||||
|
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
|
||||||
|
addr, length, ids_values = (10, 4, {1: 1337})
|
||||||
|
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
|
||||||
|
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
if raise_on_error:
|
||||||
|
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
||||||
|
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
||||||
|
else:
|
||||||
|
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
||||||
|
assert read_comm == scs.COMM_RX_TIMEOUT
|
||||||
|
|
||||||
|
assert mock_motors.stubs[stub].called
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"data_name, dxl_id, value",
|
"addr, length, ids_values",
|
||||||
[
|
[
|
||||||
("Torque_Enable", 1, 0),
|
(0, 1, {1: 4}),
|
||||||
("Torque_Enable", 1, 1),
|
(10, 2, {1: 1337, 2: 42}),
|
||||||
("Goal_Position", 2, 1337),
|
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
||||||
("Goal_Position", 3, 42),
|
|
||||||
],
|
],
|
||||||
|
ids=["1 motor", "2 motors", "3 motors"],
|
||||||
)
|
)
|
||||||
def test_write(data_name, dxl_id, value, mock_motors, dummy_motors):
|
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
|
||||||
stub_name = mock_motors.build_write_stub(data_name, dxl_id, value)
|
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
|
||||||
motors_bus = FeetechMotorsBus(
|
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
port=mock_motors.port,
|
bus.connect(handshake=False)
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
motors_bus.write(data_name, f"dummy_{dxl_id}", value, normalize=False)
|
comm = bus._sync_write(addr, length, ids_values)
|
||||||
|
|
||||||
assert mock_motors.stubs[stub_name].called
|
assert mock_motors.stubs[stub].wait_called()
|
||||||
|
assert comm == scs.COMM_SUCCESS
|
||||||
|
|
||||||
|
|
||||||
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
|
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
|
||||||
encoded_homings = {m.id: encode_sign_magnitude(m.homing_offset, 11) for m in dummy_calibration.values()}
|
encoded_homings = {m.id: encode_sign_magnitude(m.homing_offset, 11) for m in dummy_calibration.values()}
|
||||||
mins = {m.id: m.range_min for m in dummy_calibration.values()}
|
mins = {m.id: m.range_min for m in dummy_calibration.values()}
|
||||||
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
|
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
|
||||||
offsets_stub = mock_motors.build_sync_read_stub("Homing_Offset", encoded_homings)
|
offsets_stub = mock_motors.build_sync_read_stub(
|
||||||
mins_stub = mock_motors.build_sync_read_stub("Min_Position_Limit", mins)
|
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings
|
||||||
maxes_stub = mock_motors.build_sync_read_stub("Max_Position_Limit", maxes)
|
)
|
||||||
motors_bus = FeetechMotorsBus(
|
mins_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins)
|
||||||
|
maxes_stub = mock_motors.build_sync_read_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes)
|
||||||
|
bus = FeetechMotorsBus(
|
||||||
port=mock_motors.port,
|
port=mock_motors.port,
|
||||||
motors=dummy_motors,
|
motors=dummy_motors,
|
||||||
calibration=dummy_calibration,
|
calibration=dummy_calibration,
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
is_calibrated = motors_bus.is_calibrated
|
is_calibrated = bus.is_calibrated
|
||||||
|
|
||||||
assert is_calibrated
|
assert is_calibrated
|
||||||
assert mock_motors.stubs[offsets_stub].called
|
assert mock_motors.stubs[offsets_stub].called
|
||||||
|
@ -361,17 +339,20 @@ def test_reset_calibration(mock_motors, dummy_motors):
|
||||||
write_mins_stubs = []
|
write_mins_stubs = []
|
||||||
write_maxes_stubs = []
|
write_maxes_stubs = []
|
||||||
for motor in dummy_motors.values():
|
for motor in dummy_motors.values():
|
||||||
write_homing_stubs.append(mock_motors.build_write_stub("Homing_Offset", motor.id, 0))
|
write_homing_stubs.append(
|
||||||
write_mins_stubs.append(mock_motors.build_write_stub("Min_Position_Limit", motor.id, 0))
|
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
|
||||||
write_maxes_stubs.append(mock_motors.build_write_stub("Max_Position_Limit", motor.id, 4095))
|
)
|
||||||
|
write_mins_stubs.append(
|
||||||
motors_bus = FeetechMotorsBus(
|
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
|
||||||
port=mock_motors.port,
|
)
|
||||||
motors=dummy_motors,
|
write_maxes_stubs.append(
|
||||||
|
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
|
|
||||||
motors_bus.reset_calibration()
|
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
|
bus.reset_calibration()
|
||||||
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
|
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
|
||||||
|
@ -393,23 +374,24 @@ def test_set_half_turn_homings(mock_motors, dummy_motors):
|
||||||
2: -2005, # 42 - 2047
|
2: -2005, # 42 - 2047
|
||||||
3: 1625, # 3672 - 2047
|
3: 1625, # 3672 - 2047
|
||||||
}
|
}
|
||||||
read_pos_stub = mock_motors.build_sync_read_stub("Present_Position", current_positions)
|
read_pos_stub = mock_motors.build_sync_read_stub(
|
||||||
|
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions
|
||||||
|
)
|
||||||
write_homing_stubs = []
|
write_homing_stubs = []
|
||||||
for id_, homing in expected_homings.items():
|
for id_, homing in expected_homings.items():
|
||||||
encoded_homing = encode_sign_magnitude(homing, 11)
|
encoded_homing = encode_sign_magnitude(homing, 11)
|
||||||
stub = mock_motors.build_write_stub("Homing_Offset", id_, encoded_homing)
|
stub = mock_motors.build_write_stub(
|
||||||
|
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing
|
||||||
|
)
|
||||||
write_homing_stubs.append(stub)
|
write_homing_stubs.append(stub)
|
||||||
|
|
||||||
motors_bus = FeetechMotorsBus(
|
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
port=mock_motors.port,
|
bus.connect(handshake=False)
|
||||||
motors=dummy_motors,
|
bus.reset_calibration = MagicMock()
|
||||||
)
|
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
|
||||||
motors_bus.reset_calibration = MagicMock()
|
|
||||||
|
|
||||||
motors_bus.set_half_turn_homings()
|
bus.set_half_turn_homings()
|
||||||
|
|
||||||
motors_bus.reset_calibration.assert_called_once()
|
bus.reset_calibration.assert_called_once()
|
||||||
assert mock_motors.stubs[read_pos_stub].called
|
assert mock_motors.stubs[read_pos_stub].called
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
||||||
|
|
||||||
|
@ -430,16 +412,15 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors):
|
||||||
"dummy_2": 3600,
|
"dummy_2": 3600,
|
||||||
"dummy_3": 4002,
|
"dummy_3": 4002,
|
||||||
}
|
}
|
||||||
read_pos_stub = mock_motors.build_sequential_sync_read_stub("Present_Position", positions)
|
stub = mock_motors.build_sequential_sync_read_stub(
|
||||||
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions
|
||||||
motors_bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
)
|
||||||
motors_bus.connect(assert_motors_exist=False)
|
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
||||||
|
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
|
||||||
mins, maxes = motors_bus.record_ranges_of_motion(display_values=False)
|
mins, maxes = bus.record_ranges_of_motion(display_values=False)
|
||||||
|
|
||||||
assert mock_motors.stubs[read_pos_stub].calls == 3
|
assert mock_motors.stubs[stub].calls == 3
|
||||||
assert mins == expected_mins
|
assert mins == expected_mins
|
||||||
assert maxes == expected_maxes
|
assert maxes == expected_maxes
|
||||||
|
|
|
@ -1,87 +1,469 @@
|
||||||
|
# ruff: noqa: N802
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lerobot.common.motors.motors_bus import assert_same_address, get_address, get_ctrl_table
|
from lerobot.common.motors.motors_bus import (
|
||||||
|
Motor,
|
||||||
|
MotorNormMode,
|
||||||
|
MotorsBus,
|
||||||
|
assert_same_address,
|
||||||
|
get_address,
|
||||||
|
get_ctrl_table,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(aliberts)
|
DUMMY_CTRL_TABLE_1 = {
|
||||||
# class DummyMotorsBus(MotorsBus):
|
|
||||||
# def __init__(self, port: str, motors: dict[str, Motor]):
|
|
||||||
# super().__init__(port, motors)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def ctrl_table_1() -> dict:
|
|
||||||
return {
|
|
||||||
"Firmware_Version": (0, 1),
|
"Firmware_Version": (0, 1),
|
||||||
"Model_Number": (1, 2),
|
"Model_Number": (1, 2),
|
||||||
"Present_Position": (3, 4),
|
"Present_Position": (3, 4),
|
||||||
"Goal_Position": (7, 2),
|
"Goal_Position": (11, 2),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DUMMY_CTRL_TABLE_2 = {
|
||||||
@pytest.fixture
|
|
||||||
def ctrl_table_2() -> dict:
|
|
||||||
return {
|
|
||||||
"Model_Number": (0, 2),
|
"Model_Number": (0, 2),
|
||||||
"Firmware_Version": (2, 1),
|
"Firmware_Version": (2, 1),
|
||||||
"Present_Position": (3, 4),
|
"Present_Position": (3, 4),
|
||||||
"Goal_Position": (7, 4),
|
"Present_Velocity": (7, 4),
|
||||||
"Lock": (7, 4),
|
"Goal_Position": (11, 4),
|
||||||
}
|
"Goal_Velocity": (15, 4),
|
||||||
|
"Lock": (19, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_CTRL_TABLE = {
|
||||||
|
"model_1": DUMMY_CTRL_TABLE_1,
|
||||||
|
"model_2": DUMMY_CTRL_TABLE_2,
|
||||||
|
"model_3": DUMMY_CTRL_TABLE_2,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_BAUDRATE_TABLE = {
|
||||||
|
0: 1_000_000,
|
||||||
|
1: 500_000,
|
||||||
|
2: 250_000,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_BAUDRATE_TABLE = {
|
||||||
|
"model_1": DUMMY_BAUDRATE_TABLE,
|
||||||
|
"model_2": DUMMY_BAUDRATE_TABLE,
|
||||||
|
"model_3": DUMMY_BAUDRATE_TABLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_ENCODING_TABLE = {
|
||||||
|
"Present_Position": 8,
|
||||||
|
"Goal_Position": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_ENCODING_TABLE = {
|
||||||
|
"model_1": DUMMY_ENCODING_TABLE,
|
||||||
|
"model_2": DUMMY_ENCODING_TABLE,
|
||||||
|
"model_3": DUMMY_ENCODING_TABLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_NUMBER_TABLE = {
|
||||||
|
"model_1": 1234,
|
||||||
|
"model_2": 5678,
|
||||||
|
"model_3": 5799,
|
||||||
|
}
|
||||||
|
|
||||||
|
DUMMY_MODEL_RESOLUTION_TABLE = {
|
||||||
|
"model_1": 4096,
|
||||||
|
"model_2": 1024,
|
||||||
|
"model_3": 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MockPortHandler:
|
||||||
|
def __init__(self, port_name):
|
||||||
|
self.is_open: bool = False
|
||||||
|
self.baudrate: int
|
||||||
|
self.packet_start_time: float
|
||||||
|
self.packet_timeout: float
|
||||||
|
self.tx_time_per_byte: float
|
||||||
|
self.is_using: bool = False
|
||||||
|
self.port_name: str = port_name
|
||||||
|
self.ser = None
|
||||||
|
|
||||||
|
def openPort(self):
|
||||||
|
self.is_open = True
|
||||||
|
return self.is_open
|
||||||
|
|
||||||
|
def closePort(self):
|
||||||
|
self.is_open = False
|
||||||
|
|
||||||
|
def clearPort(self): ...
|
||||||
|
def setPortName(self, port_name):
|
||||||
|
self.port_name = port_name
|
||||||
|
|
||||||
|
def getPortName(self):
|
||||||
|
return self.port_name
|
||||||
|
|
||||||
|
def setBaudRate(self, baudrate):
|
||||||
|
self.baudrate: baudrate
|
||||||
|
|
||||||
|
def getBaudRate(self):
|
||||||
|
return self.baudrate
|
||||||
|
|
||||||
|
def getBytesAvailable(self): ...
|
||||||
|
def readPort(self, length): ...
|
||||||
|
def writePort(self, packet): ...
|
||||||
|
def setPacketTimeout(self, packet_length): ...
|
||||||
|
def setPacketTimeoutMillis(self, msec): ...
|
||||||
|
def isPacketTimeout(self): ...
|
||||||
|
def getCurrentTime(self): ...
|
||||||
|
def getTimeSinceStart(self): ...
|
||||||
|
def setupPort(self, cflag_baud): ...
|
||||||
|
def getCFlagBaud(self, baudrate): ...
|
||||||
|
|
||||||
|
|
||||||
|
class MockMotorsBus(MotorsBus):
|
||||||
|
available_baudrates = [500_000, 1_000_000]
|
||||||
|
default_timeout = 1000
|
||||||
|
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
|
||||||
|
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
|
||||||
|
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
|
||||||
|
model_number_table = DUMMY_MODEL_NUMBER_TABLE
|
||||||
|
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
|
||||||
|
normalized_data = ["Present_Position", "Goal_Position"]
|
||||||
|
|
||||||
|
def __init__(self, port: str, motors: dict[str, Motor]):
|
||||||
|
super().__init__(port, motors)
|
||||||
|
self.port_handler = MockPortHandler(port)
|
||||||
|
|
||||||
|
def _assert_protocol_is_compatible(self, instruction_name): ...
|
||||||
|
def _handshake(self): ...
|
||||||
|
def _find_single_motor(self, motor, initial_baudrate): ...
|
||||||
|
def configure_motors(self): ...
|
||||||
|
def read_calibration(self): ...
|
||||||
|
def write_calibration(self, calibration_dict): ...
|
||||||
|
def disable_torque(self, motors): ...
|
||||||
|
def enable_torque(self, motors): ...
|
||||||
|
def _get_half_turn_homings(self, positions): ...
|
||||||
|
def _encode_sign(self, data_name, ids_values): ...
|
||||||
|
def _decode_sign(self, data_name, ids_values): ...
|
||||||
|
def _split_into_byte_chunks(self, value, length): ...
|
||||||
|
def broadcast_ping(self, num_retry, raise_on_error): ...
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_ctrl_table(ctrl_table_1, ctrl_table_2) -> dict:
|
def dummy_motors() -> dict[str, Motor]:
|
||||||
return {
|
return {
|
||||||
"model_1": ctrl_table_1,
|
"dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100),
|
||||||
"model_2": ctrl_table_2,
|
"dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100),
|
||||||
|
"dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_get_ctrl_table(model_ctrl_table, ctrl_table_1):
|
def test_get_ctrl_table():
|
||||||
model = "model_1"
|
model = "model_1"
|
||||||
ctrl_table = get_ctrl_table(model_ctrl_table, model)
|
ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
|
||||||
assert ctrl_table == ctrl_table_1
|
assert ctrl_table == DUMMY_CTRL_TABLE_1
|
||||||
|
|
||||||
|
|
||||||
def test_get_ctrl_table_error(model_ctrl_table):
|
def test_get_ctrl_table_error():
|
||||||
model = "model_99"
|
model = "model_99"
|
||||||
with pytest.raises(KeyError, match=f"Control table for {model=} not found."):
|
with pytest.raises(KeyError, match=f"Control table for {model=} not found."):
|
||||||
get_ctrl_table(model_ctrl_table, model)
|
get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
|
||||||
|
|
||||||
|
|
||||||
def test_get_address(model_ctrl_table):
|
def test_get_address():
|
||||||
addr, n_bytes = get_address(model_ctrl_table, "model_1", "Firmware_Version")
|
addr, n_bytes = get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", "Firmware_Version")
|
||||||
assert addr == 0
|
assert addr == 0
|
||||||
assert n_bytes == 1
|
assert n_bytes == 1
|
||||||
|
|
||||||
|
|
||||||
def test_get_address_error(model_ctrl_table):
|
def test_get_address_error():
|
||||||
model = "model_1"
|
model = "model_1"
|
||||||
data_name = "Lock"
|
data_name = "Lock"
|
||||||
with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."):
|
with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."):
|
||||||
get_address(model_ctrl_table, "model_1", data_name)
|
get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", data_name)
|
||||||
|
|
||||||
|
|
||||||
def test_assert_same_address(model_ctrl_table):
|
def test_assert_same_address():
|
||||||
models = ["model_1", "model_2"]
|
models = ["model_1", "model_2"]
|
||||||
assert_same_address(model_ctrl_table, models, "Present_Position")
|
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position")
|
||||||
|
|
||||||
|
|
||||||
def test_assert_same_address_different_addresses(model_ctrl_table):
|
def test_assert_same_length_different_addresses():
|
||||||
models = ["model_1", "model_2"]
|
models = ["model_1", "model_2"]
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
NotImplementedError,
|
NotImplementedError,
|
||||||
match=re.escape("At least two motor models use a different address"),
|
match=re.escape("At least two motor models use a different address"),
|
||||||
):
|
):
|
||||||
assert_same_address(model_ctrl_table, models, "Model_Number")
|
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number")
|
||||||
|
|
||||||
|
|
||||||
def test_assert_same_address_different_bytes(model_ctrl_table):
|
def test_assert_same_address_different_length():
|
||||||
models = ["model_1", "model_2"]
|
models = ["model_1", "model_2"]
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
NotImplementedError,
|
NotImplementedError,
|
||||||
match=re.escape("At least two motor models use a different bytes representation"),
|
match=re.escape("At least two motor models use a different bytes representation"),
|
||||||
):
|
):
|
||||||
assert_same_address(model_ctrl_table, models, "Goal_Position")
|
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Goal_Position")
|
||||||
|
|
||||||
|
|
||||||
|
def test__serialize_data_invalid_length():
|
||||||
|
bus = MockMotorsBus("", {})
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
bus._serialize_data(100, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test__serialize_data_negative_numbers():
|
||||||
|
bus = MockMotorsBus("", {})
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bus._serialize_data(-1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test__serialize_data_large_number():
|
||||||
|
bus = MockMotorsBus("", {})
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data_name, id_, value",
|
||||||
|
[
|
||||||
|
("Firmware_Version", 1, 14),
|
||||||
|
("Model_Number", 1, 5678),
|
||||||
|
("Present_Position", 2, 1337),
|
||||||
|
("Present_Velocity", 3, 42),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_read(data_name, id_, value, dummy_motors):
|
||||||
|
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read,
|
||||||
|
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
|
||||||
|
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
|
||||||
|
):
|
||||||
|
returned_value = bus.read(data_name, f"dummy_{id_}")
|
||||||
|
|
||||||
|
assert returned_value == value
|
||||||
|
mock__read.assert_called_once_with(
|
||||||
|
addr,
|
||||||
|
length,
|
||||||
|
id_,
|
||||||
|
num_retry=0,
|
||||||
|
raise_on_error=True,
|
||||||
|
err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.",
|
||||||
|
)
|
||||||
|
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
|
||||||
|
if data_name in bus.normalized_data:
|
||||||
|
mock__normalize.assert_called_once_with(data_name, {id_: value})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data_name, id_, value",
|
||||||
|
[
|
||||||
|
("Goal_Position", 1, 1337),
|
||||||
|
("Goal_Velocity", 2, 3682),
|
||||||
|
("Lock", 3, 1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_write(data_name, id_, value, dummy_motors):
|
||||||
|
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write,
|
||||||
|
patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign,
|
||||||
|
patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize,
|
||||||
|
):
|
||||||
|
bus.write(data_name, f"dummy_{id_}", value)
|
||||||
|
|
||||||
|
mock__write.assert_called_once_with(
|
||||||
|
addr,
|
||||||
|
length,
|
||||||
|
id_,
|
||||||
|
value,
|
||||||
|
num_retry=0,
|
||||||
|
raise_on_error=True,
|
||||||
|
err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.",
|
||||||
|
)
|
||||||
|
mock__encode_sign.assert_called_once_with(data_name, {id_: value})
|
||||||
|
if data_name in bus.normalized_data:
|
||||||
|
mock__unnormalize.assert_called_once_with(data_name, {id_: value})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data_name, id_, value",
|
||||||
|
[
|
||||||
|
("Firmware_Version", 1, 14),
|
||||||
|
("Model_Number", 1, 5678),
|
||||||
|
("Present_Position", 2, 1337),
|
||||||
|
("Present_Velocity", 3, 42),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_sync_read_by_str(data_name, id_, value, dummy_motors):
|
||||||
|
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||||
|
ids = [id_]
|
||||||
|
expected_value = {f"dummy_{id_}": value}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read,
|
||||||
|
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
|
||||||
|
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
|
||||||
|
):
|
||||||
|
returned_dict = bus.sync_read(data_name, f"dummy_{id_}")
|
||||||
|
|
||||||
|
assert returned_dict == expected_value
|
||||||
|
mock__sync_read.assert_called_once_with(
|
||||||
|
addr,
|
||||||
|
length,
|
||||||
|
ids,
|
||||||
|
num_retry=0,
|
||||||
|
raise_on_error=True,
|
||||||
|
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
|
||||||
|
)
|
||||||
|
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
|
||||||
|
if data_name in bus.normalized_data:
|
||||||
|
mock__normalize.assert_called_once_with(data_name, {id_: value})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data_name, ids_values",
|
||||||
|
[
|
||||||
|
("Model_Number", {1: 5678}),
|
||||||
|
("Present_Position", {1: 1337, 2: 42}),
|
||||||
|
("Present_Velocity", {1: 1337, 2: 42, 3: 4016}),
|
||||||
|
],
|
||||||
|
ids=["1 motor", "2 motors", "3 motors"],
|
||||||
|
)
|
||||||
|
def test_sync_read_by_list(data_name, ids_values, dummy_motors):
|
||||||
|
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||||
|
ids = list(ids_values)
|
||||||
|
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
|
||||||
|
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
|
||||||
|
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
|
||||||
|
):
|
||||||
|
returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids])
|
||||||
|
|
||||||
|
assert returned_dict == expected_values
|
||||||
|
mock__sync_read.assert_called_once_with(
|
||||||
|
addr,
|
||||||
|
length,
|
||||||
|
ids,
|
||||||
|
num_retry=0,
|
||||||
|
raise_on_error=True,
|
||||||
|
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
|
||||||
|
)
|
||||||
|
mock__decode_sign.assert_called_once_with(data_name, ids_values)
|
||||||
|
if data_name in bus.normalized_data:
|
||||||
|
mock__normalize.assert_called_once_with(data_name, ids_values)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data_name, ids_values",
|
||||||
|
[
|
||||||
|
("Model_Number", {1: 5678, 2: 5799, 3: 5678}),
|
||||||
|
("Present_Position", {1: 1337, 2: 42, 3: 4016}),
|
||||||
|
("Goal_Position", {1: 4008, 2: 199, 3: 3446}),
|
||||||
|
],
|
||||||
|
ids=["Model_Number", "Present_Position", "Goal_Position"],
|
||||||
|
)
|
||||||
|
def test_sync_read_by_none(data_name, ids_values, dummy_motors):
|
||||||
|
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||||
|
ids = list(ids_values)
|
||||||
|
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
|
||||||
|
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
|
||||||
|
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
|
||||||
|
):
|
||||||
|
returned_dict = bus.sync_read(data_name)
|
||||||
|
|
||||||
|
assert returned_dict == expected_values
|
||||||
|
mock__sync_read.assert_called_once_with(
|
||||||
|
addr,
|
||||||
|
length,
|
||||||
|
ids,
|
||||||
|
num_retry=0,
|
||||||
|
raise_on_error=True,
|
||||||
|
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
|
||||||
|
)
|
||||||
|
mock__decode_sign.assert_called_once_with(data_name, ids_values)
|
||||||
|
if data_name in bus.normalized_data:
|
||||||
|
mock__normalize.assert_called_once_with(data_name, ids_values)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data_name, value",
|
||||||
|
[
|
||||||
|
("Goal_Position", 500),
|
||||||
|
("Goal_Velocity", 4010),
|
||||||
|
("Lock", 0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_sync_write_by_single_value(data_name, value, dummy_motors):
|
||||||
|
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||||
|
ids_values = {m.id: value for m in dummy_motors.values()}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
|
||||||
|
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
|
||||||
|
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
|
||||||
|
):
|
||||||
|
bus.sync_write(data_name, value)
|
||||||
|
|
||||||
|
mock__sync_write.assert_called_once_with(
|
||||||
|
addr,
|
||||||
|
length,
|
||||||
|
ids_values,
|
||||||
|
num_retry=0,
|
||||||
|
raise_on_error=True,
|
||||||
|
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
|
||||||
|
)
|
||||||
|
mock__encode_sign.assert_called_once_with(data_name, ids_values)
|
||||||
|
if data_name in bus.normalized_data:
|
||||||
|
mock__unnormalize.assert_called_once_with(data_name, ids_values)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data_name, ids_values",
|
||||||
|
[
|
||||||
|
("Goal_Position", {1: 1337, 2: 42, 3: 4016}),
|
||||||
|
("Goal_Velocity", {1: 50, 2: 83, 3: 2777}),
|
||||||
|
("Lock", {1: 0, 2: 0, 3: 1}),
|
||||||
|
],
|
||||||
|
ids=["Goal_Position", "Goal_Velocity", "Lock"],
|
||||||
|
)
|
||||||
|
def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors):
|
||||||
|
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
||||||
|
bus.connect(handshake=False)
|
||||||
|
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
||||||
|
values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
|
||||||
|
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
|
||||||
|
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
|
||||||
|
):
|
||||||
|
bus.sync_write(data_name, values)
|
||||||
|
|
||||||
|
mock__sync_write.assert_called_once_with(
|
||||||
|
addr,
|
||||||
|
length,
|
||||||
|
ids_values,
|
||||||
|
num_retry=0,
|
||||||
|
raise_on_error=True,
|
||||||
|
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
|
||||||
|
)
|
||||||
|
mock__encode_sign.assert_called_once_with(data_name, ids_values)
|
||||||
|
if data_name in bus.normalized_data:
|
||||||
|
mock__unnormalize.assert_called_once_with(data_name, ids_values)
|
||||||
|
|
|
@ -172,8 +172,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
# TODO(rcadene, aliberts): test video=True
|
# TODO(rcadene, aliberts): test video=True
|
||||||
video=False,
|
video=False,
|
||||||
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
|
display_data=False,
|
||||||
display_cameras=False,
|
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
)
|
)
|
||||||
dataset = record(robot, rec_cfg)
|
dataset = record(robot, rec_cfg)
|
||||||
|
@ -226,7 +225,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||||
num_episodes=2,
|
num_episodes=2,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
num_image_writer_processes=num_image_writer_processes,
|
num_image_writer_processes=num_image_writer_processes,
|
||||||
)
|
)
|
||||||
|
@ -273,7 +272,7 @@ def test_resume_record(tmp_path, request, robot_type, mock):
|
||||||
episode_time_s=1,
|
episode_time_s=1,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
num_episodes=1,
|
num_episodes=1,
|
||||||
)
|
)
|
||||||
|
@ -330,7 +329,7 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
|
||||||
num_episodes=1,
|
num_episodes=1,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
)
|
)
|
||||||
dataset = record(robot, rec_cfg)
|
dataset = record(robot, rec_cfg)
|
||||||
|
@ -380,7 +379,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
|
||||||
num_episodes=1,
|
num_episodes=1,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -433,7 +432,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n
|
||||||
num_episodes=2,
|
num_episodes=2,
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
video=False,
|
video=False,
|
||||||
display_cameras=False,
|
display_data=False,
|
||||||
play_sounds=False,
|
play_sounds=False,
|
||||||
num_image_writer_processes=num_image_writer_processes,
|
num_image_writer_processes=num_image_writer_processes,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue