Integrate Trossen AI Arms. (#7)
* Trossen Arms Working
* Trossen Arm Driver
* Clean Up
* Update Config and Debug
* docker doc
* Updated code for new trossen arm driver
* PR Iteration 1
* Fix Code
* Fix documentation
* Fixed Reviewed Code
* White space fix

---------

Co-authored-by: shantanuparabumd <sparab@umd.edu>
@@ -0,0 +1,222 @@
This tutorial explains how to use [Trossen AI Bimanual](https://www.trossenrobotics.com/stationary-ai) with LeRobot.

## Setup

Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/trossen_arm/main/getting_started/hardware_setup.html) for setting up the hardware.

## Install LeRobot

On your computer:

1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```

2. Restart your shell or run `source ~/.bashrc`.

3. Create and activate a fresh conda environment for LeRobot:
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
4. Clone LeRobot:
```bash
git clone https://github.com/Interbotix/lerobot.git ~/lerobot
```

5. Install LeRobot with the dependencies for the Trossen AI arms (trossen-arm) and the Intel RealSense cameras (pyrealsense2):

```bash
cd ~/lerobot && pip install -e ".[trossen_ai]"
```

For Linux only (not Mac), install extra dependencies for recording datasets:

```bash
conda install -y -c conda-forge ffmpeg
pip uninstall -y opencv-python
conda install -y -c conda-forge "opencv>=4.10.0"
```
## Troubleshooting

If you encounter the following error:

```bash
ImportError: /xxx/xxx/xxx/envs/lerobot/lib/python3.10/site-packages/cv2/python-3.10/../../../.././libtiff.so.6: undefined symbol: jpeg12_write_raw_data, version LIBJPEG_8.0
```

there are two known system-specific solutions:

### System76 Serval Workstation (serw13) & Dell Precision 7670

```bash
conda install pytorch==2.5.1=cpu_openblas_py310ha613aac_2 -y
conda install torchvision==0.21.0 -y
```

### HP

```bash
pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 torchaudio==2.5.1+cu121 --index-url https://download.pytorch.org/whl/cu121
```
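Optionally, a quick sanity check after applying either fix is to confirm that OpenCV and PyTorch import cleanly inside the `lerobot` environment. This snippet is only an illustrative check, not a required step:

```python
# Both imports should succeed without the libtiff/libjpeg symbol error above.
import cv2
import torch

print("opencv:", cv2.__version__)
print("torch:", torch.__version__, "cuda available:", torch.cuda.is_available())
```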
## Teleoperate

By running the following command, you can start your first **SAFE** teleoperation:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=trossen_ai_bimanual \
  --robot.max_relative_target=5 \
  --control.type=teleoperate
```

By adding `--robot.max_relative_target=5`, we explicitly set the `max_relative_target` defined in [`TrossenAIBimanualRobotConfig`](lerobot/common/robot_devices/robots/configs.py). The value `5` limits the magnitude of each movement for more safety, but it also makes the teleoperation less smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=trossen_ai_bimanual \
  --robot.max_relative_target=null \
  --control.type=teleoperate
```
## Record a dataset

Once you're familiar with teleoperation, you can record your first dataset with Trossen AI.

If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):

```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```

Store your Hugging Face repository name in a variable to run these commands:

```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```

Record 2 episodes and upload your dataset to the hub:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=trossen_ai_bimanual \
  --robot.max_relative_target=null \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/trossen_ai_bimanual_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=2 \
  --control.push_to_hub=true
```
Note: If the camera fps is unstable, consider increasing the number of image writer threads per camera:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=trossen_ai_bimanual \
  --robot.max_relative_target=null \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/trossen_ai_bimanual_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=2 \
  --control.push_to_hub=true \
  --control.num_image_writer_threads_per_camera=8
```
## Visualize a dataset

If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy-pasting your repo id, given by:
```bash
echo ${HF_USER}/trossen_ai_bimanual_test
```

If you didn't upload the dataset (i.e. you recorded with `--control.push_to_hub=false`), you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
  --repo-id ${HF_USER}/trossen_ai_bimanual_test
```
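You can also inspect the recorded data directly in Python. The sketch below assumes the `LeRobotDataset` class shipped with this repository and attributes such as `num_episodes` and `fps`; exact names may vary between LeRobot versions:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Replace <hf_user> with your Hugging Face username (the value of $HF_USER above).
dataset = LeRobotDataset("<hf_user>/trossen_ai_bimanual_test")

print(dataset.num_episodes)  # expected: 2
print(dataset.fps)           # expected: 30

frame = dataset[0]  # a single frame: dict with camera images, robot state and action tensors
print(list(frame.keys()))
```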
## Replay an episode

Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=trossen_ai_bimanual \
  --robot.max_relative_target=null \
  --control.type=replay \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/trossen_ai_bimanual_test \
  --control.episode=0
```
## Train a policy

To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
  --dataset.repo_id=${HF_USER}/trossen_ai_bimanual_test \
  --policy.type=act \
  --output_dir=outputs/train/act_trossen_ai_bimanual_test \
  --job_name=act_trossen_ai_bimanual_test \
  --device=cuda \
  --wandb.enable=true
```

Let's explain it:
1. We provided the dataset as an argument with `--dataset.repo_id=${HF_USER}/trossen_ai_bimanual_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy automatically adapts to the number of motor states, motor actions, and cameras of your robot (e.g. `cam_high`, `cam_left_wrist`, `cam_right_wrist`), which have been saved in your dataset.
3. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional, but if you use it, make sure you are logged in by running `wandb login`.

For more information on the `train` script, see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)

Training should take several hours. You will find checkpoints in `outputs/train/act_trossen_ai_bimanual_test/checkpoints`.
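Once a checkpoint exists, you can load it back in Python for quick offline checks. This is a minimal sketch that assumes `ACTPolicy.from_pretrained` accepts a local checkpoint directory (or a hub repo id), as with other LeRobot policies; the exact API may differ between versions:

```python
import torch

from lerobot.common.policies.act.modeling_act import ACTPolicy

# Load the last checkpoint written by the train script.
policy = ACTPolicy.from_pretrained(
    "outputs/train/act_trossen_ai_bimanual_test/checkpoints/last/pretrained_model"
)
policy.eval()
policy.to("cuda" if torch.cuda.is_available() else "cpu")
```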
## Evaluate your policy

You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py \
  --robot.type=trossen_ai_bimanual \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a lego block and put it in the bin." \
  --control.repo_id=${HF_USER}/eval_act_trossen_ai_bimanual_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=10 \
  --control.push_to_hub=true \
  --control.policy.path=outputs/train/act_trossen_ai_bimanual_test/checkpoints/last/pretrained_model \
  --control.num_image_writer_processes=1
```

As you can see, it's almost the same command as the one previously used to record your training dataset. Three things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_trossen_ai_bimanual_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_trossen_ai_bimanual_test`).
2. The name of the dataset begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_trossen_ai_bimanual_test`).
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the cameras to disk allows reaching a consistent 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
## More

Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation.
@@ -247,7 +247,7 @@ def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
     fps: int,
-    vcodec: str = "libsvtav1",
+    vcodec: str = "libx264",
     pix_fmt: str = "yuv420p",
     g: int | None = 2,
     crf: int | None = 30,
@@ -39,3 +39,10 @@ class FeetechMotorsBusConfig(MotorsBusConfig):
    port: str
    motors: dict[str, tuple[int, str]]
    mock: bool = False


@MotorsBusConfig.register_subclass("trossen_arm_driver")
@dataclass
class TrossenArmDriverConfig(MotorsBusConfig):
    ip: str
    model: dict[str, tuple[int, str]]
    mock: bool = False
@@ -0,0 +1,266 @@
import time
import traceback

import numpy as np
import trossen_arm as trossen

from lerobot.common.robot_devices.motors.configs import TrossenArmDriverConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError

PITCH_CIRCLE_RADIUS = 0.00875  # meters
VEL_LIMITS = [3.375, 3.375, 3.375, 7.0, 7.0, 7.0, 12.5 * PITCH_CIRCLE_RADIUS]

TROSSEN_ARM_MODELS = {
    "V0_LEADER": [trossen.Model.wxai_v0, trossen.StandardEndEffector.wxai_v0_leader],
    "V0_FOLLOWER": [trossen.Model.wxai_v0, trossen.StandardEndEffector.wxai_v0_follower],
}
class TrossenArmDriver:
    """
    The `TrossenArmDriver` class provides an interface for controlling Trossen Robotics' robotic arms.
    It leverages the trossen_arm package for communication with the arms.

    This class allows for configuration, torque management, and motion control of robotic arms. It
    includes features for handling connection states, moving the arm to specified poses, and logging
    timestamps for debugging and performance analysis.

    ### Key Features:
    - **Multi-motor Control:** Supports multiple motors connected to a bus.
    - **Mode Switching:** Enables switching between position and gravity compensation modes.
    - **Home and Sleep Pose Management:** Automatically transitions the arm to home and sleep poses for safe operation.
    - **Error Handling:** Raises specific exceptions for connection and operational errors.
    - **Logging:** Captures timestamps for operations to aid in debugging.

    ### Example Usage:
    ```python
    motors = {
        "joint_0": (1, "4340"),
        "joint_1": (2, "4340"),
        "joint_2": (4, "4340"),
        "joint_3": (6, "4310"),
        "joint_4": (7, "4310"),
        "joint_5": (8, "4310"),
        "joint_6": (9, "4310"),
    }
    arm_driver = TrossenArmDriver(
        motors=motors,
        ip="192.168.1.2",
        model="V0_LEADER",
    )
    arm_driver.connect()

    # Read motor positions
    positions = arm_driver.read("Present_Position")

    # Move to a new position (Home Pose)
    # Last joint is the gripper, which is in range [0, 450]
    arm_driver.write("Goal_Position", [0, 15, 15, 0, 0, 0, 200])

    # Disconnect when done
    arm_driver.disconnect()
    ```
    """
    def __init__(
        self,
        config: TrossenArmDriverConfig,
    ):
        self.ip = config.ip
        self.model = config.model
        self.mock = config.mock
        self.driver = None
        self.calibration = None
        self.is_connected = False
        self.group_readers = {}
        self.group_writers = {}
        self.logs = {}
        self.fps = 30
        self.home_pose = [0, np.pi / 12, np.pi / 12, 0, 0, 0, 0]
        self.sleep_pose = [0, 0, 0, 0, 0, 0, 0]

        self.motors = {
            # name: (index, model)
            "joint_0": [1, "4340"],
            "joint_1": [2, "4340"],
            "joint_2": [3, "4340"],
            "joint_3": [4, "4310"],
            "joint_4": [5, "4310"],
            "joint_5": [6, "4310"],
            "joint_6": [7, "4310"],
        }

        self.prev_write_time = 0
        self.current_write_time = None

        # To prevent DiscontinuityError due to large jumps in position in short time,
        # we scale the time to move based on the distance between the start and goal values
        # and the maximum speed of the motors. The factor below is used to scale the time to move.
        self.TIME_SCALING_FACTOR = 3.0

        # Minimum time to move for the arm (this is a tuning parameter)
        self.MIN_TIME_TO_MOVE = 3.0 / self.fps
    def connect(self):
        print(f"Connecting to {self.model} arm at {self.ip}...")
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                f"TrossenArmDriver({self.ip}) is already connected. Do not call `motors_bus.connect()` twice."
            )

        print("Initializing the drivers...")

        # Initialize the driver
        self.driver = trossen.TrossenArmDriver()

        # Get the model configuration
        try:
            model_name, model_end_effector = TROSSEN_ARM_MODELS[self.model]
        except KeyError:
            raise ValueError(f"Unsupported model: {self.model}")

        print("Configuring the drivers...")

        # Configure the driver
        try:
            self.driver.configure(model_name, model_end_effector, self.ip, True)
        except Exception:
            traceback.print_exc()
            print(f"Failed to configure the driver for the {self.model} arm at {self.ip}.")
            raise

        # Move the arms to the home pose
        self.driver.set_all_modes(trossen.Mode.position)
        self.driver.set_all_positions(self.home_pose, 2.0, True)

        # Allow to read and write
        self.is_connected = True
    def reconnect(self):
        try:
            model_name, model_end_effector = TROSSEN_ARM_MODELS[self.model]
        except KeyError:
            raise ValueError(f"Unsupported model: {self.model}")
        try:
            self.driver.configure(model_name, model_end_effector, self.ip, True)
        except Exception:
            traceback.print_exc()
            print(f"Failed to configure the driver for the {self.model} arm at {self.ip}.")
            raise

        self.is_connected = True
    @property
    def motor_names(self) -> list[str]:
        return list(self.motors.keys())

    @property
    def motor_models(self) -> list[str]:
        return [model for _, model in self.motors.values()]

    @property
    def motor_indices(self) -> list[int]:
        return [idx for idx, _ in self.motors.values()]

    def set_calibration(self, calibration: dict[str, list]):
        self.calibration = calibration

    def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
        pass

    def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        pass

    def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        pass

    def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
        pass
    def read(self, data_name, motor_names: str | list[str] | None = None):
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"TrossenArmMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
            )

        start_time = time.perf_counter()

        # Read the present position of the motors
        if data_name == "Present_Position":
            # Get the positions of the motors
            values = self.driver.get_positions()
            values[:-1] = np.degrees(values[:-1])  # Convert all joints except gripper
            values[-1] = values[-1] * 10000  # Convert gripper to range (0-450)
        else:
            values = None
            print(f"Data name: {data_name} is not supported for reading.")

        # TODO: Add support for reading other data names as required

        self.logs["delta_timestamp_s_read"] = time.perf_counter() - start_time

        values = np.array(values, dtype=np.float32)
        return values

    def compute_time_to_move(self, goal_values: np.ndarray):
        # Compute the time to move based on the distance between the start and goal values
        # and the maximum speed of the motors
        current_pose = self.driver.get_positions()
        displacement = abs(goal_values - current_pose)
        time_to_move_all_joints = self.TIME_SCALING_FACTOR * displacement / VEL_LIMITS
        time_to_move = max(time_to_move_all_joints)
        time_to_move = max(time_to_move, self.MIN_TIME_TO_MOVE)
        return time_to_move
    def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"TrossenAIArm({self.port}) is not connected. You need to run `motors_bus.connect()`."
            )

        start_time = time.perf_counter()

        # Write the goal position of the motors
        if data_name == "Goal_Position":
            values = np.array(values, dtype=np.float32)
            # Convert back to radians for joints
            values[:-1] = np.radians(values[:-1])  # Convert all joints except gripper
            values[-1] = values[-1] / 10000  # Convert gripper back to range (0-0.045)
            self.driver.set_all_positions(values.tolist(), self.compute_time_to_move(values), False)
            self.prev_write_time = self.current_write_time

        # Enable or disable the torque of the motors
        elif data_name == "Torque_Enable":
            # Set the arms to POSITION mode
            if values == 1:
                self.driver.set_all_modes(trossen.Mode.position)
            else:
                self.driver.set_all_modes(trossen.Mode.external_effort)
                self.driver.set_all_external_efforts([0.0] * 7, 0.0, True)
        elif data_name == "Reset":
            self.driver.set_all_modes(trossen.Mode.position)
            self.driver.set_all_positions(self.home_pose, 2.0, True)
        else:
            print(f"Data name: {data_name} value: {values} is not supported for writing.")

        self.logs["delta_timestamp_s_write"] = time.perf_counter() - start_time
    def disconnect(self):
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"Trossen Arm Driver ({self.port}) is not connected. Try running `motors_bus.connect()` first."
            )
        self.driver.set_all_modes(trossen.Mode.position)
        self.driver.set_all_positions(self.home_pose, 2.0, True)
        self.driver.set_all_positions(self.sleep_pose, 2.0, True)

        self.is_connected = False

    def __del__(self):
        if getattr(self, "is_connected", False):
            self.disconnect()
@@ -43,6 +43,10 @@ def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig
            from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus

            motors_buses[key] = FeetechMotorsBus(cfg)
        elif cfg.type == "trossen_arm_driver":
            from lerobot.common.robot_devices.motors.trossen_arm_driver import TrossenArmDriver

            motors_buses[key] = TrossenArmDriver(cfg)

        else:
            raise ValueError(f"The motor type '{cfg.type}' is not valid.")
@@ -27,6 +27,7 @@ from lerobot.common.robot_devices.motors.configs import (
    DynamixelMotorsBusConfig,
    FeetechMotorsBusConfig,
    MotorsBusConfig,
    TrossenArmDriverConfig,
)
@@ -611,3 +612,76 @@ class LeKiwiRobotConfig(RobotConfig):
    )

    mock: bool = False


@RobotConfig.register_subclass("trossen_ai_bimanual")
@dataclass
class TrossenAIBimanualRobotConfig(ManipulatorRobotConfig):

    # /!\ FOR SAFETY, READ THIS /!\
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    # For Trossen AI Arms, for every goal position request, motor rotations are capped at 5 degrees by default.
    # When you feel more confident with teleoperation or running the policy, you can extend
    # this safety limit and even remove it by setting it to `null`.
    # Also, everything is expected to work safely out-of-the-box, but we highly advise to
    # first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
    # then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully.
    max_relative_target: int | None = 5

    leader_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": TrossenArmDriverConfig(
                # wxai
                ip="192.168.1.3",
                model="V0_LEADER",
            ),
            "right": TrossenArmDriverConfig(
                # wxai
                ip="192.168.1.2",
                model="V0_LEADER",
            ),
        }
    )

    follower_arms: dict[str, MotorsBusConfig] = field(
        default_factory=lambda: {
            "left": TrossenArmDriverConfig(
                ip="192.168.1.5",
                model="V0_FOLLOWER",
            ),
            "right": TrossenArmDriverConfig(
                ip="192.168.1.4",
                model="V0_FOLLOWER",
            ),
        }
    )

    # Troubleshooting: If one of your IntelRealSense cameras freeze during
    # data recording due to bandwidth limit, you might need to plug the camera
    # on another USB hub or PCIe card.
    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "cam_high": IntelRealSenseCameraConfig(
                serial_number=130322270184,
                fps=30,
                width=640,
                height=480,
            ),
            "cam_left_wrist": IntelRealSenseCameraConfig(
                serial_number=218622274938,
                fps=30,
                width=640,
                height=480,
            ),
            "cam_right_wrist": IntelRealSenseCameraConfig(
                serial_number=128422271347,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False
@@ -160,7 +160,8 @@ class ManipulatorRobot:
     ):
         self.config = config
         self.robot_type = self.config.type
-        self.calibration_dir = Path(self.config.calibration_dir)
+        if not self.robot_type =="trossen_ai_bimanual":
+            self.calibration_dir = Path(self.config.calibration_dir)
         self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
         self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
         self.cameras = make_cameras_from_configs(self.config.cameras)
@@ -222,6 +223,15 @@ class ManipulatorRobot:
            available_arms.append(arm_id)
        return available_arms

    def teleop_safety_stop(self):
        if self.robot_type in ["trossen_ai_bimanual"]:
            for arms in self.follower_arms:
                self.follower_arms[arms].write("Reset", 1)
                self.follower_arms[arms].write("Torque_Enable", 1)
            for arms in self.leader_arms:
                self.leader_arms[arms].write("Reset", 1)
                self.leader_arms[arms].write("Torque_Enable", 0)

    def connect(self):
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
@@ -241,7 +251,7 @@ class ManipulatorRobot:
             print(f"Connecting {name} leader arm.")
             self.leader_arms[name].connect()

-        if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
+        if self.robot_type in ["koch", "koch_bimanual", "aloha", "trossen_ai_bimanual"]:
             from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
         elif self.robot_type in ["so100", "moss", "lekiwi"]:
             from lerobot.common.robot_devices.motors.feetech import TorqueMode
@@ -253,7 +263,9 @@ class ManipulatorRobot:
         for name in self.leader_arms:
             self.leader_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value)

-        self.activate_calibration()
+        if not self.robot_type == "trossen_ai_bimanual":
+            print("Checking if calibration is needed.")
+            self.activate_calibration()

         # Set robot preset (e.g. torque in leader gripper for Koch v1.1)
         if self.robot_type in ["koch", "koch_bimanual"]:
@@ -24,6 +24,7 @@ from lerobot.common.robot_devices.robots.configs import (
    RobotConfig,
    So100RobotConfig,
    StretchRobotConfig,
    TrossenAIBimanualRobotConfig,
)
@@ -62,6 +63,8 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
        return StretchRobotConfig(**kwargs)
    elif robot_type == "lekiwi":
        return LeKiwiRobotConfig(**kwargs)
    elif robot_type == "trossen_ai_bimanual":
        return TrossenAIBimanualRobotConfig(**kwargs)
    else:
        raise ValueError(f"Robot type '{robot_type}' is not available.")
@@ -93,6 +93,7 @@ stretch = [
    "pynput>=1.7.7",
]
test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
trossen_ai = ["pyrealsense2>=2.55.1.6486", "trossen-arm>=1.6.0"]
umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]