Merge 8e28f1ee60 into 1c873df5c0
commit 02576b7a65

@@ -0,0 +1,144 @@
# Using the [roarm_m3](https://github.com/waveshareteam/roarm_m3) with LeRobot

## Table of Contents

- [A. Install LeRobot](#a-install-lerobot)
- [B. Teleoperate](#b-teleoperate)
- [C. Record a dataset](#c-record-a-dataset)
- [D. Visualize a dataset](#d-visualize-a-dataset)
- [E. Replay an episode](#e-replay-an-episode)
- [F. Train a policy](#f-train-a-policy)
- [G. Evaluate your policy](#g-evaluate-your-policy)
- [H. More Information](#h-more-information)

## A. Install LeRobot

Before running the following commands, make sure you have installed LeRobot by following the [installation instructions](https://github.com/huggingface/lerobot/blob/main/README.md).
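
For reference, a minimal install sketch (paraphrased from the upstream README; exact Python version and extras may differ for your setup):

```bash
# clone the repository and create a dedicated environment
git clone https://github.com/huggingface/lerobot.git
cd lerobot
conda create -y -n lerobot python=3.10
conda activate lerobot
# install LeRobot in editable mode
pip install -e .
```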

## B. Teleoperate

**Simple teleop**

#### a. Teleop without displaying cameras

You will be able to teleoperate your robot without connecting to or displaying the cameras:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=roarm_m3 \
  --robot.cameras='{}' \
  --control.type=teleoperate
```

#### b. Teleop with displaying cameras

You can display the cameras while teleoperating by running the following command. This is useful to prepare your setup before recording your first dataset.

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=roarm_m3 \
  --control.type=teleoperate
```

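By default, teleoperation runs until you stop it. If you only want a timed session, recent LeRobot versions expose a duration flag on the teleoperate control config (flag name assumed here; verify against your version of `lerobot/common/robot_devices/control_configs.py`):

```bash
# teleoperate for 30 seconds, then exit (assumed flag: --control.teleop_time_s)
python lerobot/scripts/control_robot.py \
  --robot.type=roarm_m3 \
  --control.type=teleoperate \
  --control.teleop_time_s=30
```
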
## C. Record a dataset

Once you're familiar with teleoperation, you can record your first dataset with the roarm_m3.

If you want to use the Hugging Face Hub features for uploading your dataset and you haven't done so previously, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):

```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```

Store your Hugging Face repository name in a variable to run these commands:

```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```

Record 2 episodes and upload your dataset to the hub:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=roarm_m3 \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a block and put it in the bin." \
  --control.repo_id=${HF_USER}/roarm_m3_test \
  --control.tags='["roarm_m3","tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=2 \
  --control.push_to_hub=true
```

Note: You can resume recording by adding `--control.resume=true`.
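
For example, a sketch of resuming the same session to append two more episodes to the same dataset (same flags as above, plus the resume switch):

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=roarm_m3 \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a block and put it in the bin." \
  --control.repo_id=${HF_USER}/roarm_m3_test \
  --control.num_episodes=2 \
  --control.resume=true
```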

## D. Visualize a dataset

If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy-pasting your repo id, given by:

```bash
echo ${HF_USER}/roarm_m3_test
```

If you didn't upload the dataset (i.e. you used `--control.push_to_hub=false`), you can still visualize it locally. The visualization tool serves a page in your browser at `http://ip:9090` (replace `ip` below with your machine's IP address):

```bash
python lerobot/scripts/visualize_dataset_html.py \
  --repo-id ${HF_USER}/roarm_m3_test \
  --host ip \
  --local-files-only 1
```

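You can also inspect the dataset directly in Python. A minimal sketch, assuming the `LeRobotDataset` class is at its usual location in this codebase and the dataset is in your local cache or on the hub:

```python
# load the recorded dataset and print basic stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("your_hf_user/roarm_m3_test")
print(dataset.num_episodes, "episodes,", dataset.num_frames, "frames")

# each item is a dict of tensors: observation.state, action, camera images, ...
frame = dataset[0]
print(frame.keys())
```
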
## E. Replay an episode

Now try to replay the nth episode on your robot. Episodes are zero-indexed, so the nth episode has index n-1:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=roarm_m3 \
  --control.type=replay \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/roarm_m3_test \
  --control.episode=n-1
```

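For instance, to replay the first of the two episodes recorded earlier:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=roarm_m3 \
  --control.type=replay \
  --control.fps=30 \
  --control.repo_id=${HF_USER}/roarm_m3_test \
  --control.episode=0
```
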
## F. Train a policy

To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:

```bash
python lerobot/scripts/train.py \
  --dataset.repo_id=${HF_USER}/roarm_m3_test \
  --policy.type=act \
  --output_dir=outputs/train/act_roarm_m3_test \
  --job_name=act_roarm_m3_test \
  --device=cuda \
  --wandb.enable=true
```

Let's explain the arguments:

1. We provided the dataset with `--dataset.repo_id=${HF_USER}/roarm_m3_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy automatically adapts to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) as saved in your dataset.
3. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional, but if you use it, make sure you are logged in by running `wandb login`.

Training should take several hours. You will find checkpoints in `outputs/train/act_roarm_m3_test/checkpoints`.

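Once training finishes, you can optionally push the final checkpoint to the hub so it can be referenced by repo id later. A sketch using the standard Hugging Face CLI (the model repo name is an assumption, chosen to match the job name):

```bash
# upload the last checkpoint folder to a model repo on the hub
huggingface-cli upload ${HF_USER}/act_roarm_m3_test \
  outputs/train/act_roarm_m3_test/checkpoints/last/pretrained_model
```
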
## G. Evaluate your policy

You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:

```bash
python lerobot/scripts/control_robot.py \
  --robot.type=roarm_m3 \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a block and put it in the bin." \
  --control.repo_id=${HF_USER}/eval_act_roarm_m3_test \
  --control.tags='["tutorial"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=30 \
  --control.reset_time_s=30 \
  --control.num_episodes=10 \
  --control.push_to_hub=true \
  --control.policy.path=outputs/train/act_roarm_m3_test/checkpoints/last/pretrained_model
```

As you can see, this is almost the same command previously used to record your training dataset. Two things changed:

1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_roarm_m3_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_roarm_m3_test`).
2. The dataset name begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_roarm_m3_test`).

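The evaluation episodes form a regular dataset, so you can inspect them with the same visualization tool as in section D:

```bash
python lerobot/scripts/visualize_dataset_html.py \
  --repo-id ${HF_USER}/eval_act_roarm_m3_test \
  --host ip \
  --local-files-only 1
```
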
## H. More Information

See this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth walkthrough of controlling real robots with LeRobot.

@@ -166,6 +166,9 @@ def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
            for arm in robot_cfg.leader_arms
            for motor in robot_cfg.leader_arms[arm].motors
        ]
    elif robot_cfg.type == "roarm_m3":
        state_names = ["1", "2", "3", "4", "5", "6"]
        action_names = ["1", "2", "3", "4", "5", "6"]
    # elif robot_cfg["robot_type"] == "stretch3": TODO
    else:
        raise NotImplementedError(

@@ -44,6 +44,20 @@ from lerobot.common.utils.utils import capture_timestamp_utc
# treat the same cameras as new devices. Thus we select a higher bound to search indices.
MAX_OPENCV_INDEX = 60

# Module-level flag: when True, every captured frame is undistorted before use.
undistort = True


def undistort_image(image):
    import cv2

    # Hardcoded intrinsics and distortion coefficients, presumably calibrated for the roarm_m3 camera.
    camera_matrix = np.array([[289.11451, 0.0, 347.23664], [0.0, 289.75319, 235.67429], [0.0, 0.0, 1.0]])

    dist_coeffs = np.array([-0.208848, 0.028006, -0.000705, -0.000820, 0.000000])

    undistorted_image = cv2.undistort(image, camera_matrix, dist_coeffs)

    return undistorted_image


def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
    cameras = []

@@ -404,6 +418,9 @@ class OpenCVCamera:

        color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)

        if undistort:
            color_image = undistort_image(color_image)

        h, w, _ = color_image.shape
        if h != self.capture_height or w != self.capture_width:
            raise OSError(

@@ -128,42 +128,73 @@ def predict_action(observation, policy, device, use_amp):
    return action


# def init_keyboard_listener():
#     # Allow to exit early while recording an episode or resetting the environment,
#     # by tapping the right arrow key '->'. This might require a sudo permission
#     # to allow your terminal to monitor keyboard events.
#     events = {}
#     events["exit_early"] = False
#     events["rerecord_episode"] = False
#     events["stop_recording"] = False

#     if is_headless():
#         logging.warning(
#             "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
#         )
#         listener = None
#         return listener, events

#     # Only import pynput if not in a headless environment
#     from pynput import keyboard

#     def on_press(key):
#         try:
#             if key == keyboard.Key.right:
#                 print("Right arrow key pressed. Exiting loop...")
#                 events["exit_early"] = True
#             elif key == keyboard.Key.left:
#                 print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
#                 events["rerecord_episode"] = True
#                 events["exit_early"] = True
#             elif key == keyboard.Key.esc:
#                 print("Escape key pressed. Stopping data recording...")
#                 events["stop_recording"] = True
#                 events["exit_early"] = True
#         except Exception as e:
#             print(f"Error handling key press: {e}")

#     listener = keyboard.Listener(on_press=on_press)
#     listener.start()

#     return listener, events


def init_keyboard_listener():
    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    events = {}
    events["exit_early"] = False
    events["rerecord_episode"] = False
    events["stop_recording"] = False
    import threading

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        listener = None
        return listener, events

    # Only import sshkeyboard if not in a headless environment
    from sshkeyboard import listen_keyboard

    def on_press(key):
        try:
            if key == "right":
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
            elif key == "left":
                print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == "q":
                print("Q key pressed. Stopping data recording...")
                events["stop_recording"] = True
                events["exit_early"] = True
        except Exception as e:
            print(f"Error handling key press: {e}")

    # sshkeyboard's listen_keyboard blocks, so run it in a background thread
    # (replaces the previous pynput keyboard.Listener).
    listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
    listener.start()

    return listener, events

@@ -264,7 +295,8 @@ def control_loop(
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        # if display_cameras and not is_headless():
        if display_cameras:
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))

@@ -297,15 +329,27 @@ def reset_environment(robot, events, reset_time_s, fps):
    )


# def stop_recording(robot, listener, display_cameras):
#     robot.disconnect()

#     if not is_headless():
#         if listener is not None:
#             listener.stop()

#     if display_cameras:
#         cv2.destroyAllWindows()


def stop_recording(robot, listener, display_cameras):
    robot.disconnect()

    from sshkeyboard import stop_listening

    # Stop the sshkeyboard listener thread (replaces the pynput listener.stop() call).
    if listener is not None:
        stop_listening()

    if display_cameras:
        cv2.destroyAllWindows()


def sanity_check_dataset_name(repo_id, policy_cfg):

@@ -494,6 +494,47 @@ class So100RobotConfig(ManipulatorRobotConfig):
    mock: bool = False


@RobotConfig.register_subclass("roarm_m3")
@dataclass
class RoarmRobotConfig(RobotConfig):
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None

    # Leader arm configuration: serial port of the leader arm
    leader_arms: dict[str, str] = field(
        default_factory=lambda: {
            "main": "/dev/ttyUSB0",
        }
    )

    # Follower arm configuration: serial port of the follower arm
    follower_arms: dict[str, str] = field(
        default_factory=lambda: {
            "main": "/dev/ttyUSB1",
        }
    )

    cameras: dict[str, CameraConfig] = field(
        default_factory=lambda: {
            "laptop": OpenCVCameraConfig(
                camera_index=0,
                fps=30,
                width=640,
                height=480,
            ),
            "phone": OpenCVCameraConfig(
                camera_index=2,
                fps=30,
                width=640,
                height=480,
            ),
        }
    )

    mock: bool = False


@RobotConfig.register_subclass("stretch")
@dataclass
class StretchRobotConfig(RobotConfig):

@@ -0,0 +1,305 @@
"""Contains logic to instantiate a robot, read information from its motors and cameras,
and send orders to its motors.
"""
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
# calibration procedure, to make it easy for people to add their own robot.

import logging
import time

import numpy as np
import torch
from roarm_sdk.roarm import roarm

from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError


def ensure_safe_goal_position(
    goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
):
    # Cap relative action target magnitude for safety.
    diff = goal_pos - present_pos
    max_relative_target = torch.tensor(max_relative_target)
    safe_diff = torch.minimum(diff, max_relative_target)
    safe_diff = torch.maximum(safe_diff, -max_relative_target)
    safe_goal_pos = present_pos + safe_diff

    if not torch.allclose(goal_pos, safe_goal_pos):
        logging.warning(
            "Relative goal position magnitude had to be clamped to be safe.\n"
            f"  requested relative goal position target: {diff}\n"
            f"    clamped relative goal position target: {safe_diff}"
        )

    return safe_goal_pos


def make_roarm_from_configs(configs: dict[str, str]) -> dict[str, roarm]:
    roarms = {}

    for key, port in configs.items():
        roarms[key] = roarm(roarm_type="roarm_m3", port=port, baudrate=115200)

    return roarms


class RoarmRobot:
    def __init__(
        self,
        config: RoarmRobotConfig,
    ):
        self.config = config
        self.robot_type = self.config.type
        self.leader_arms = make_roarm_from_configs(self.config.leader_arms)
        self.follower_arms = make_roarm_from_configs(self.config.follower_arms)
        self.cameras = make_cameras_from_configs(self.config.cameras)
        self.is_connected = False
        self.logs = {}

    @property
    def camera_features(self) -> dict:
        cam_ft = {}
        for cam_key, cam in self.cameras.items():
            key = f"observation.images.{cam_key}"
            cam_ft[key] = {
                "shape": (cam.height, cam.width, cam.channels),
                "names": ["height", "width", "channels"],
                "info": None,
            }
        return cam_ft

    @property
    def motor_features(self) -> dict:
        return {
            "action": {
                "dtype": "float32",
                "shape": (6,),
                "names": ["1", "2", "3", "4", "5", "6"],
            },
            "observation.state": {
                "dtype": "float32",
                "shape": (6,),
                "names": ["1", "2", "3", "4", "5", "6"],
            },
        }

    @property
    def features(self):
        return {**self.motor_features, **self.camera_features}

    @property
    def has_camera(self):
        return len(self.cameras) > 0

    @property
    def num_cameras(self):
        return len(self.cameras)

    def connect(self):
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(
                "RoarmRobot is already connected. Do not run `robot.connect()` twice."
            )

        if not self.leader_arms and not self.follower_arms and not self.cameras:
            raise ValueError(
                "RoarmRobot doesn't have any device to connect. See example of usage in docstring of the class."
            )

        # Reading the joint angles acts as a connectivity check for each arm.
        for name in self.follower_arms:
            self.follower_arms[name].joints_angle_get()
        for name in self.leader_arms:
            self.leader_arms[name].joints_angle_get()

        # Connect the cameras
        for name in self.cameras:
            self.cameras[name].connect()

        self.is_connected = True

    def teleop_step(
        self, record_data=False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "RoarmRobot is not connected. You need to run `robot.connect()`."
            )

        # Prepare to assign the position of the leader to the follower
        leader_pos = {}
        for name in self.leader_arms:
            before_lread_t = time.perf_counter()
            leader_pos[name] = np.array(self.leader_arms[name].joints_angle_get(), dtype=np.float32)
            leader_pos[name] = torch.from_numpy(leader_pos[name])
            self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t

        # Send goal position to the follower
        follower_goal_pos = {}
        for name in self.follower_arms:
            before_fwrite_t = time.perf_counter()
            goal_pos = leader_pos[name]

            # Cap goal position when too far away from present position.
            # Slower fps expected due to reading from the follower.
            if self.config.max_relative_target is not None:
                present_pos = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
                present_pos = torch.from_numpy(present_pos)
                goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)

            # Used when record_data=True
            follower_goal_pos[name] = goal_pos

            goal_pos = goal_pos.numpy().astype(np.int32)
            goal_pos = goal_pos.tolist() if isinstance(goal_pos, (np.ndarray, torch.Tensor)) else goal_pos
            self.follower_arms[name].joints_angle_ctrl(angles=goal_pos, speed=0, acc=0)
            self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t

        # Early exit when recording data is not requested
        if not record_data:
            return

        # TODO(rcadene): Add velocity and other info
        # Read follower position
        follower_pos = {}
        for name in self.follower_arms:
            before_fread_t = time.perf_counter()
            follower_pos[name] = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
            follower_pos[name] = torch.from_numpy(follower_pos[name])
            self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t

        # Create state by concatenating follower current position
        state = []
        for name in self.follower_arms:
            if name in follower_pos:
                state.append(follower_pos[name])
        state = torch.cat(state)

        # Create action by concatenating follower goal position
        action = []
        for name in self.follower_arms:
            if name in follower_goal_pos:
                action.append(follower_goal_pos[name])
        action = torch.cat(action)

        # Capture images from cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate output dictionaries
        obs_dict, action_dict = {}, {}
        obs_dict["observation.state"] = state
        action_dict["action"] = action
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]

        return obs_dict, action_dict

    def capture_observation(self):
        """The returned observations do not have a batch dimension."""
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "RoarmRobot is not connected. You need to run `robot.connect()`."
            )

        # Read follower position
        follower_pos = {}
        for name in self.follower_arms:
            before_fread_t = time.perf_counter()
            follower_pos[name] = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
            follower_pos[name] = torch.from_numpy(follower_pos[name])
            self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t

        # Create state by concatenating follower current position
        state = []
        for name in self.follower_arms:
            if name in follower_pos:
                state.append(follower_pos[name])
        state = torch.cat(state)

        # Capture images from cameras
        images = {}
        for name in self.cameras:
            before_camread_t = time.perf_counter()
            images[name] = self.cameras[name].async_read()
            images[name] = torch.from_numpy(images[name])
            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

        # Populate output dictionaries and format to pytorch
        obs_dict = {}
        obs_dict["observation.state"] = state
        for name in self.cameras:
            obs_dict[f"observation.images.{name}"] = images[name]
        return obs_dict

    def send_action(self, action: torch.Tensor) -> torch.Tensor:
        """Command the follower arms to move to a target joint configuration.

        The relative action magnitude may be clipped depending on the configuration parameter
        `max_relative_target`. In this case, the action sent differs from original action.
        Thus, this function always returns the action actually sent.

        Args:
            action: tensor containing the concatenated goal positions for the follower arms.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "RoarmRobot is not connected. You need to run `robot.connect()`."
            )

        from_idx = 0
        to_idx = 6
        action_sent = []
        for name in self.follower_arms:
            # Get goal position of each follower arm by splitting the action vector
            goal_pos = action[from_idx:to_idx]
            from_idx = to_idx

            # Cap goal position when too far away from present position.
            # Slower fps expected due to reading from the follower.
            if self.config.max_relative_target is not None:
                present_pos = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
                present_pos = torch.from_numpy(present_pos)
                goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)

            # Save tensor to concat and return
            action_sent.append(goal_pos)

            # Send goal position to each follower
            goal_pos = goal_pos.numpy().astype(np.int32)
            goal_pos = goal_pos.tolist() if isinstance(goal_pos, (np.ndarray, torch.Tensor)) else goal_pos
            self.follower_arms[name].joints_angle_ctrl(angles=goal_pos, speed=0, acc=0)

        return torch.cat(action_sent)

    def print_logs(self):
        pass
        # TODO(aliberts): move robot-specific logs logic here

    def disconnect(self):
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                "RoarmRobot is not connected. You need to run `robot.connect()` before disconnecting."
            )

        for name in self.follower_arms:
            self.follower_arms[name].disconnect()

        for name in self.leader_arms:
            self.leader_arms[name].disconnect()

        for name in self.cameras:
            self.cameras[name].disconnect()

        self.is_connected = False

    def __del__(self):
        if getattr(self, "is_connected", False):
            self.disconnect()

@@ -21,6 +21,7 @@ from lerobot.common.robot_devices.robots.configs import (
    LeKiwiRobotConfig,
    ManipulatorRobotConfig,
    MossRobotConfig,
    RoarmRobotConfig,
    RobotConfig,
    So100RobotConfig,
    StretchRobotConfig,

@@ -58,6 +59,8 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
        return MossRobotConfig(**kwargs)
    elif robot_type == "so100":
        return So100RobotConfig(**kwargs)
    elif robot_type == "roarm_m3":
        return RoarmRobotConfig(**kwargs)
    elif robot_type == "stretch":
        return StretchRobotConfig(**kwargs)
    elif robot_type == "lekiwi":

@@ -75,6 +78,10 @@ def make_robot_from_config(config: RobotConfig):
        from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator

        return MobileManipulator(config)
    elif isinstance(config, RoarmRobotConfig):
        from lerobot.common.robot_devices.robots.roarm_m3 import RoarmRobot

        return RoarmRobot(config)
    else:
        from lerobot.common.robot_devices.robots.stretch import StretchRobot

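Taken together, these hooks make the new robot constructible from its type string. A quick interactive sketch (the helpers' module path is assumed from the hunk context above):

```python
# Sketch: instantiate and drive the roarm_m3 through the factory helpers.
from lerobot.common.robot_devices.robots.utils import make_robot_config, make_robot_from_config

config = make_robot_config("roarm_m3")  # -> RoarmRobotConfig with default ports/cameras
robot = make_robot_from_config(config)  # -> RoarmRobot
robot.connect()

# One leader->follower sync step; also returns the frame that would be recorded.
obs, action = robot.teleop_step(record_data=True)
print(obs["observation.state"], action["action"])

robot.disconnect()
```
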
@@ -73,6 +73,7 @@ dependencies = [
    "torchvision>=0.21.0",
    "wandb>=0.16.3",
    "zarr>=2.17.0",
    "sshkeyboard",
]

[project.optional-dependencies]

@@ -96,6 +97,7 @@ test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
roarm = ["roarm-sdk==0.0.11"]

[tool.poetry]
requires-poetry = ">=2.1"