Integrate Trossen AI Arms. (#7)

* Trossen Arms Working

* Trossen Arm Driver

* Clean Up

* Update Config and Debug

* docker doc

* Updated code for  new trossen arm driver

* PR Iteration 1

* Fix Code

* Fix documentation

* Fixed Reviewed Code

* White space fix

---------

Co-authored-by: shantanuparabumd <sparab@umd.edu>
This commit is contained in:
shantanuparab-tr 2025-03-06 14:59:49 -06:00 committed by Luke Schmitt
parent 145fe4cd17
commit 0f26b32c79
No known key found for this signature in database
GPG Key ID: E9957A727DDEB555
9 changed files with 593 additions and 4 deletions

View File

@ -0,0 +1,222 @@
This tutorial explains how to use [Trossen AI Bimanual](https://www.trossenrobotics.com/stationary-ai) with LeRobot.
## Setup
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/trossen_arm/main/getting_started/hardware_setup.html) for setting up the hardware.
## Install LeRobot
On your computer:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
2. Restart shell or `source ~/.bashrc`
3. Create and activate a fresh conda environment for lerobot
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
4. Clone LeRobot:
```bash
git clone https://github.com/Interbotix/lerobot.git ~/lerobot
```
5. Install LeRobot with dependencies for the Trossen AI arms (trossen-arm) and cameras (intelrealsense):
```bash
cd ~/lerobot && pip install -e ".[trossen_ai]"
```
For Linux only (not Mac), install extra dependencies for recording datasets:
```bash
conda install -y -c conda-forge ffmpeg
pip uninstall -y opencv-python
conda install -y -c conda-forge "opencv>=4.10.0"
```
## Troubleshooting
If you encounter the following error.
```bash
ImportError: /xxx/xxx/xxx/envs/lerobot/lib/python3.10/site-packages/cv2/python-3.10/../../../.././libtiff.so.6: undefined symbol: jpeg12_write_raw_data, version LIBJPEG_8.0
```
The below are the 2 known system specific solutions
### System 76 Serval Workstation (serw13) & Dell Precision 7670
```bash
conda install pytorch==2.5.1=cpu_openblas_py310ha613aac_2 -y
conda install torchvision==0.21.0 -y
```
### HP
```bash
pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 torchaudio==2.5.1+cu121 --index-url https://download.pytorch.org/whl/cu121
```
## Teleoperate
By running the following code, you can start your first **SAFE** teleoperation:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=trossen_ai_bimanual \
--robot.max_relative_target=5 \
--control.type=teleoperate
```
By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`TrossenAIBimanualRobot`](lerobot/common/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=trossen_ai \
--robot.max_relative_target=null \
--control.type=teleoperate
```
## Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with Trossen AI.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record 2 episodes and upload your dataset to the hub:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=trossen_ai_bimanual \
--robot.max_relative_target=null \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/trossen_ai_bimanual_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true
```
Note: If the camera fps is unstable consider increasing the number of image writers per thread.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=trossen_ai_bimanual \
--robot.max_relative_target=null \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/trossen_ai_bimanual_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true \
--control.num_image_writer_threads_per_camera = 8
```
## Visualize a dataset
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
```bash
echo ${HF_USER}/trossen_ai_bimanual_test
```
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id ${HF_USER}/trossen_ai_bimanual_test
```
## Replay an episode
Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=trossen_ai_bimanual \
--robot.max_relative_target=null \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/trossen_ai_bimanual_test \
--control.episode=0
```
## Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
--dataset.repo_id=${HF_USER}/trossen_ai_bimanual_test \
--policy.type=act \
--output_dir=outputs/train/act_trossen_ai_bimanual_test \
--job_name=act_trossen_ai_bimanual_test \
--device=cuda \
--wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/trossen_ai_bimanual_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
Training should take several hours. You will find checkpoints in `outputs/train/act_trossen_ai_bimanual_test/checkpoints`.
## Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=trossen_ai_bimanual \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/eval_act_trossen_ai_bimanual_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=10 \
--control.push_to_hub=true \
--control.policy.path=outputs/train/act_trossen_ai_bimanual_test/checkpoints/last/pretrained_model \
--control.num_image_writer_processes=1
```
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_trossen_ai_bimanual_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_trossen_ai_bimanual_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_trossen_ai_bimanual_test`).
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constent 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
## More
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.

View File

@ -247,7 +247,7 @@ def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
vcodec: str = "libsvtav1",
vcodec: str = "libx264",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,

View File

@ -39,3 +39,10 @@ class FeetechMotorsBusConfig(MotorsBusConfig):
port: str
motors: dict[str, tuple[int, str]]
mock: bool = False
@MotorsBusConfig.register_subclass("trossen_arm_driver")
@dataclass
class TrossenArmDriverConfig(MotorsBusConfig):
ip: str
model: dict[str, tuple[int, str]]
mock: bool = False

View File

@ -0,0 +1,266 @@
import time
import traceback
import trossen_arm as trossen
import numpy as np
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.motors.configs import TrossenArmDriverConfig
PITCH_CIRCLE_RADIUS = 0.00875 # meters
VEL_LIMITS = [3.375, 3.375, 3.375, 7.0, 7.0, 7.0, 12.5 * PITCH_CIRCLE_RADIUS]
TROSSEN_ARM_MODELS = {
"V0_LEADER": [trossen.Model.wxai_v0, trossen.StandardEndEffector.wxai_v0_leader],
"V0_FOLLOWER": [trossen.Model.wxai_v0, trossen.StandardEndEffector.wxai_v0_follower],
}
class TrossenArmDriver:
"""
The `TrossenArmDriver` class provides an interface for controlling
Trossen Robotics' robotic arms. It leverages the trossen_arm for communication with arms.
This class allows for configuration, torque management, and motion control of robotic arms. It includes features for handling connection states, moving the
arm to specified poses, and logging timestamps for debugging and performance analysis.
### Key Features:
- **Multi-motor Control:** Supports multiple motors connected to a bus.
- **Mode Switching:** Enables switching between position and gravity
compensation modes.
- **Home and Sleep Pose Management:** Automatically transitions the arm to home and sleep poses for safe operation.
- **Error Handling:** Raises specific exceptions for connection and operational errors.
- **Logging:** Captures timestamps for operations to aid in debugging.
### Example Usage:
```python
motors = {
"joint_0": (1, "4340"),
"joint_1": (2, "4340"),
"joint_2": (4, "4340"),
"joint_3": (6, "4310"),
"joint_4": (7, "4310"),
"joint_5": (8, "4310"),
"joint_6": (9, "4310"),
}
arm_driver = TrossenArmDriver(
motors=motors,
ip="192.168.1.2",
model="V0_LEADER",
)
arm_driver.connect()
# Read motor positions
positions = arm_driver.read("Present_Position")
# Move to a new position (Home Pose)
# Last joint is the gripper, which is in range [0, 450]
arm_driver.write("Goal_Position", [0, 15, 15, 0, 0, 0, 200])
# Disconnect when done
arm_driver.disconnect()
```
"""
def __init__(
self,
config: TrossenArmDriverConfig,
):
self.ip = config.ip
self.model = config.model
self.mock = config.mock
self.driver = None
self.calibration = None
self.is_connected = False
self.group_readers = {}
self.group_writers = {}
self.logs = {}
self.fps = 30
self.home_pose = [0, np.pi/12, np.pi/12, 0, 0, 0, 0]
self.sleep_pose = [0, 0, 0, 0, 0, 0, 0]
self.motors={
# name: (index, model)
"joint_0": [1, "4340"],
"joint_1": [2, "4340"],
"joint_2": [3, "4340"],
"joint_3": [4, "4310"],
"joint_4": [5, "4310"],
"joint_5": [6, "4310"],
"joint_6": [7, "4310"],
}
self.prev_write_time = 0
self.current_write_time = None
# To prevent DiscontinuityError due to large jumps in position in short time.
# We scale the time to move based on the distance between the start and goal values and the maximum speed of the motors.
# The below factor is used to scale the time to move.
self.TIME_SCALING_FACTOR = 3.0
# Minimum time to move for the arm (This is a tuning parameter)
self.MIN_TIME_TO_MOVE = 3.0 / self.fps
def connect(self):
print(f"Connecting to {self.model} arm at {self.ip}...")
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
f"TrossenArmDriver({self.ip}) is already connected. Do not call `motors_bus.connect()` twice."
)
print("Initializing the drivers...")
# Initialize the driver
self.driver = trossen.TrossenArmDriver()
# Get the model configuration
try:
model_name, model_end_effector = TROSSEN_ARM_MODELS[self.model]
except KeyError:
raise ValueError(f"Unsupported model: {self.model}")
print("Configuring the drivers...")
# Configure the driver
try:
self.driver.configure(model_name, model_end_effector, self.ip, True)
except Exception:
traceback.print_exc()
print(
f"Failed to configure the driver for the {self.model} arm at {self.ip}."
)
raise
# Move the arms to the home pose
self.driver.set_all_modes(trossen.Mode.position)
self.driver.set_all_positions(self.home_pose, 2.0, True)
# Allow to read and write
self.is_connected = True
def reconnect(self):
try:
model_name, model_end_effector = TROSSEN_ARM_MODELS[self.model]
except KeyError:
raise ValueError(f"Unsupported model: {self.model}")
try:
self.driver.configure(model_name, model_end_effector, self.ip, True)
except Exception:
traceback.print_exc()
print(
f"Failed to configure the driver for the {self.model} arm at {self.ip}."
)
raise
self.is_connected = True
@property
def motor_names(self) -> list[str]:
return list(self.motors.keys())
@property
def motor_models(self) -> list[str]:
return [model for _, model in self.motors.values()]
@property
def motor_indices(self) -> list[int]:
return [idx for idx, _ in self.motors.values()]
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
pass
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
pass
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
pass
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
pass
def read(self, data_name, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"TrossenArmMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
)
start_time = time.perf_counter()
# Read the present position of the motors
if data_name == "Present_Position":
# Get the positions of the motors
values = self.driver.get_positions()
values[:-1] = np.degrees(values[:-1]) # Convert all joints except gripper
values[-1] = values[-1] * 10000 # Convert gripper to range (0-450)
else:
values = None
print(f"Data name: {data_name} is not supported for reading.")
# TODO: Add support for reading other data names as required
self.logs["delta_timestamp_s_read"] = time.perf_counter() - start_time
values = np.array(values, dtype=np.float32)
return values
def compute_time_to_move(self, goal_values: np.ndarray):
# Compute the time to move based on the distance between the start and goal values
# and the maximum speed of the motors
current_pose = self.driver.get_positions()
displacement = abs(goal_values - current_pose)
time_to_move_all_joints = self.TIME_SCALING_FACTOR*displacement / VEL_LIMITS
time_to_move = max(time_to_move_all_joints)
time_to_move = max(time_to_move, self.MIN_TIME_TO_MOVE)
return time_to_move
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"TrossenAIArm({self.port}) is not connected. You need to run `motors_bus.connect()`."
)
start_time = time.perf_counter()
# Write the goal position of the motors
if data_name == "Goal_Position":
values = np.array(values, dtype=np.float32)
# Convert back to radians for joints
values[:-1] = np.radians(values[:-1]) # Convert all joints except gripper
values[-1] = values[-1] / 10000 # Convert gripper back to range (0-0.045)
self.driver.set_all_positions(values.tolist(), self.compute_time_to_move(values), False)
self.prev_write_time = self.current_write_time
# Enable or disable the torque of the motors
elif data_name == "Torque_Enable":
# Set the arms to POSITION mode
if values == 1:
self.driver.set_all_modes(trossen.Mode.position)
else:
self.driver.set_all_modes(trossen.Mode.external_effort)
self.driver.set_all_external_efforts([0.0] * 7, 0.0, True)
elif data_name == "Reset":
self.driver.set_all_modes(trossen.Mode.position)
self.driver.set_all_positions(self.home_pose, 2.0, True)
else:
print(f"Data name: {data_name} value: {values} is not supported for writing.")
self.logs["delta_timestamp_s_write"] = time.perf_counter() - start_time
def disconnect(self):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"Trossen Arm Driver ({self.port}) is not connected. Try running `motors_bus.connect()` first."
)
self.driver.set_all_modes(trossen.Mode.position)
self.driver.set_all_positions(self.home_pose, 2.0, True)
self.driver.set_all_positions(self.sleep_pose, 2.0, True)
self.is_connected = False
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()

View File

@ -43,6 +43,10 @@ def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
motors_buses[key] = FeetechMotorsBus(cfg)
elif cfg.type == "trossen_arm_driver":
from lerobot.common.robot_devices.motors.trossen_arm_driver import TrossenArmDriver
motors_buses[key] = TrossenArmDriver(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")

View File

@ -27,6 +27,7 @@ from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig,
FeetechMotorsBusConfig,
MotorsBusConfig,
TrossenArmDriverConfig,
)
@ -611,3 +612,76 @@ class LeKiwiRobotConfig(RobotConfig):
)
mock: bool = False
@RobotConfig.register_subclass("trossen_ai_bimanual")
@dataclass
class TrossenAIBimanualRobotConfig(ManipulatorRobotConfig):
# /!\ FOR SAFETY, READ THIS /!\
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
# For Trossen AI Arms, for every goal position request, motor rotations are capped at 5 degrees by default.
# When you feel more confident with teleoperation or running the policy, you can extend
# this safety limit and even removing it by setting it to `null`.
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
max_relative_target: int | None = 5
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": TrossenArmDriverConfig(
# wxai
ip="192.168.1.3",
model="V0_LEADER",
),
"right": TrossenArmDriverConfig(
# wxai
ip="192.168.1.2",
model="V0_LEADER",
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": TrossenArmDriverConfig(
ip="192.168.1.5",
model="V0_FOLLOWER",
),
"right": TrossenArmDriverConfig(
ip="192.168.1.4",
model = "V0_FOLLOWER",
),
}
)
# Troubleshooting: If one of your IntelRealSense cameras freeze during
# data recording due to bandwidth limit, you might need to plug the camera
# on another USB hub or PCIe card.
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"cam_high": IntelRealSenseCameraConfig(
serial_number=130322270184,
fps=30,
width=640,
height=480,
),
"cam_left_wrist": IntelRealSenseCameraConfig(
serial_number=218622274938,
fps=30,
width=640,
height=480,
),
"cam_right_wrist": IntelRealSenseCameraConfig(
serial_number=128422271347,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False

View File

@ -160,7 +160,8 @@ class ManipulatorRobot:
):
self.config = config
self.robot_type = self.config.type
self.calibration_dir = Path(self.config.calibration_dir)
if not self.robot_type =="trossen_ai_bimanual":
self.calibration_dir = Path(self.config.calibration_dir)
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
self.cameras = make_cameras_from_configs(self.config.cameras)
@ -222,6 +223,15 @@ class ManipulatorRobot:
available_arms.append(arm_id)
return available_arms
def teleop_safety_stop(self):
if self.robot_type in ["trossen_ai_bimanual"]:
for arms in self.follower_arms:
self.follower_arms[arms].write("Reset", 1)
self.follower_arms[arms].write("Torque_Enable", 1)
for arms in self.leader_arms:
self.leader_arms[arms].write("Reset", 1)
self.leader_arms[arms].write("Torque_Enable", 0)
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
@ -241,7 +251,7 @@ class ManipulatorRobot:
print(f"Connecting {name} leader arm.")
self.leader_arms[name].connect()
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
if self.robot_type in ["koch", "koch_bimanual", "aloha", "trossen_ai_bimanual"]:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.motors.feetech import TorqueMode
@ -253,7 +263,9 @@ class ManipulatorRobot:
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value)
self.activate_calibration()
if not self.robot_type == "trossen_ai_bimanual":
print("Checking if calibration is needed.")
self.activate_calibration()
# Set robot preset (e.g. torque in leader gripper for Koch v1.1)
if self.robot_type in ["koch", "koch_bimanual"]:

View File

@ -24,6 +24,7 @@ from lerobot.common.robot_devices.robots.configs import (
RobotConfig,
So100RobotConfig,
StretchRobotConfig,
TrossenAIBimanualRobotConfig,
)
@ -62,6 +63,8 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
return StretchRobotConfig(**kwargs)
elif robot_type == "lekiwi":
return LeKiwiRobotConfig(**kwargs)
elif robot_type == "trossen_ai_bimanual":
return TrossenAIBimanualRobotConfig(**kwargs)
else:
raise ValueError(f"Robot type '{robot_type}' is not available.")

View File

@ -93,6 +93,7 @@ stretch = [
"pynput>=1.7.7",
]
test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
trossen_ai = ["pyrealsense2>=2.55.1.6486", "trossen-arm>=1.6.0"]
umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]