Added roarm_m3 robotic arm support, and modified the keyboard listening method to work in screenless (headless) mode

DUDULRX 2025-03-05 20:00:42 +08:00
parent d694ea1d38
commit 496eb6298e
18 changed files with 671 additions and 171 deletions

examples/12_use_roarm_m3.md Normal file
View File

@@ -0,0 +1,140 @@
# Using the [roarm_m3](https://github.com/waveshareteam/roarm_m3) with LeRobot
## Table of Contents
- [A. Teleoperate](#a-teleoperate)
- [B. Record a dataset](#b-record-a-dataset)
- [C. Visualize a dataset](#c-visualize-a-dataset)
- [D. Replay an episode](#d-replay-an-episode)
- [E. Train a policy](#e-train-a-policy)
- [F. Evaluate your policy](#f-evaluate-your-policy)
- [G. More Information](#g-more-information)
## A. Teleoperate
**Simple teleop**
#### a. Teleop without displaying cameras
You will be able to teleoperate your robot without connecting to or displaying the cameras:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=roarm_m3 \
--robot.cameras='{}' \
--control.type=teleoperate
```
#### b. Teleop with displaying cameras
You can display the cameras while teleoperating by running the following command. This is useful for preparing your setup before recording your first dataset.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=roarm_m3 \
--control.type=teleoperate
```
## B. Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with roarm_m3.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
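This assumes your token is already stored in the `HUGGINGFACE_TOKEN` environment variable; alternatively, run `huggingface-cli login` without the `--token` flag and paste the token when prompted.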
Store your Hugging Face username in a variable to run these commands:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record 2 episodes and upload your dataset to the hub:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=roarm_m3 \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a block and put it in the bin." \
--control.repo_id=${HF_USER}/roarm_m3_test \
--control.tags='["roarm_m3","tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true
```
Note: You can resume recording by adding `--control.resume=true`.
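For instance, to append two more episodes to the same dataset (a sketch reusing the flags from the command above):
```bash
python lerobot/scripts/control_robot.py \
--robot.type=roarm_m3 \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a block and put it in the bin." \
--control.repo_id=${HF_USER}/roarm_m3_test \
--control.num_episodes=2 \
--control.resume=true
```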
## C. Visualize a dataset
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy-pasting your repo id, given by:
```bash
echo ${HF_USER}/roarm_m3_test
```
If you didn't upload the dataset (i.e. you recorded with `--control.push_to_hub=false`), you can also visualize it locally; a window can be opened in the browser at `http://ip:9090` with the visualization tool:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id ${HF_USER}/roarm_m3_test \
--host ip \
--local-files-only 1
```
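For example, assuming the default port, to serve the page on all network interfaces of the machine:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id ${HF_USER}/roarm_m3_test \
--host 0.0.0.0 \
--port 9090 \
--local-files-only 1
```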
## D. Replay an episode
Now try to replay the nth episode on your robot (episodes are zero-indexed, so pass `n-1` to `--control.episode`):
```bash
python lerobot/scripts/control_robot.py \
--robot.type=roarm_m3 \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/roarm_m3_test \
--control.episode=n-1
```
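For example, `--control.episode=0` replays the first of the two episodes recorded earlier.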
## E. Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
--dataset.repo_id=${HF_USER}/roarm_m3_test \
--policy.type=act \
--output_dir=outputs/train/act_roarm_m3_test \
--job_name=act_roarm_m3_test \
--device=cuda \
--wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/roarm_m3_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy automatically adapts to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional, but if you use it, make sure you are logged in by running `wandb login`.
Training should take several hours. You will find checkpoints in `outputs/train/act_roarm_m3_test/checkpoints`.
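Optionally, you can push the final checkpoint to the hub so it can be reused from anywhere (a sketch, assuming the checkpoint layout above):
```bash
huggingface-cli upload ${HF_USER}/act_roarm_m3_test \
outputs/train/act_roarm_m3_test/checkpoints/last/pretrained_model
```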
## F. Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=roarm_m3 \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a block and put it in the bin." \
--control.repo_id=${HF_USER}/eval_act_roarm_m3_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=10 \
--control.push_to_hub=true \
--control.policy.path=outputs/train/act_roarm_m3_test/checkpoints/last/pretrained_model
```
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_roarm_m3_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_roarm_m3_test`).
2. The name of the dataset begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_roarm_m3_test`).
## G. More Information
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth guide to controlling real robots with LeRobot.

View File

@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import shutil
from pathlib import Path
@@ -28,7 +27,6 @@ import torch.utils
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
@@ -314,7 +312,7 @@ class LeRobotDatasetMetadata:
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
obj.root.mkdir(parents=True, exist_ok=True)
if robot is not None:
features = get_features_from_robot(robot, use_videos)
@@ -480,7 +478,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.image_writer = None
self.episode_buffer = None
self.root.mkdir(exist_ok=True, parents=True)
self.root.mkdir(exist_ok=False, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(
@@ -519,7 +517,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
branch: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
private: bool = False,
allow_patterns: list[str] | str | None = None,
@@ -565,11 +562,6 @@
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,

View File

@@ -31,7 +31,6 @@ import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
@@ -326,19 +325,6 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
)
hub_versions = get_repo_versions(repo_id)
if not hub_versions:
raise RevisionNotFoundError(
f"""Your dataset must be tagged with a codebase version.
Assuming _version_ is the codebase_version value in the info.json, you can run this:
```python
from huggingface_hub import HfApi
hub_api = HfApi()
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
```
"""
)
if target_version in hub_versions:
return f"v{target_version}"

View File

@@ -166,6 +166,9 @@ def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
for arm in robot_cfg.leader_arms
for motor in robot_cfg.leader_arms[arm].motors
]
elif robot_cfg.type == "roarm_m3":
state_names = ["roam_m3", "roam_m3", "roam_m3", "roam_m3", "roam_m3", "roam_m3"]
action_names = ["roam_m3", "roam_m3", "roam_m3", "roam_m3", "roam_m3", "roam_m3"]
# elif robot_cfg["robot_type"] == "stretch3": TODO
else:
raise NotImplementedError(

View File

@@ -57,7 +57,7 @@ def convert_dataset(
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
dataset.push_to_hub(branch=branch, allow_patterns="meta/")
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:

View File

@@ -313,7 +313,7 @@ class PI0Policy(PreTrainedPolicy):
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("actions_is_pad")
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)

View File

@@ -30,6 +30,20 @@ from lerobot.common.utils.utils import capture_timestamp_utc
# treat the same cameras as new devices. Thus we select a higher bound to search indices.
MAX_OPENCV_INDEX = 60
undistort = True
def undistort_image(image):
import cv2
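# NOTE: these intrinsics and distortion coefficients are calibration values for one
# specific camera; other hardware needs its own calibration (e.g. via cv2.calibrateCamera).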
camera_matrix = np.array([
[289.11451, 0., 347.23664],
[0., 289.75319, 235.67429],
[0., 0., 1.]
])
dist_coeffs = np.array([-0.208848, 0.028006, -0.000705, -0.000820, 0.000000])
undistorted_image = cv2.undistort(image, camera_matrix, dist_coeffs)
return undistorted_image
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
cameras = []
@@ -368,6 +382,9 @@ class OpenCVCamera:
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
if undistort:
color_image = undistort_image(color_image)
h, w, _ = color_image.shape
if h != self.height or w != self.width:
raise OSError(

View File

@@ -31,7 +31,7 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
cameras[key] = IntelRealSenseCamera(cfg)
else:
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
return cameras

View File

@@ -113,46 +113,73 @@ def predict_action(observation, policy, device, use_amp):
return action
# def init_keyboard_listener():
# # Allow to exit early while recording an episode or resetting the environment,
# # by tapping the right arrow key '->'. This might require a sudo permission
# # to allow your terminal to monitor keyboard events.
# events = {}
# events["exit_early"] = False
# events["rerecord_episode"] = False
# events["stop_recording"] = False
# if is_headless():
# logging.warning(
# "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
# )
# listener = None
# return listener, events
# # Only import pynput if not in a headless environment
# from pynput import keyboard
# def on_press(key):
# try:
# if key == keyboard.Key.right:
# print("Right arrow key pressed. Exiting loop...")
# events["exit_early"] = True
# elif key == keyboard.Key.left:
# print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
# events["rerecord_episode"] = True
# events["exit_early"] = True
# elif key == keyboard.Key.esc:
# print("Escape key pressed. Stopping data recording...")
# events["stop_recording"] = True
# events["exit_early"] = True
# except Exception as e:
# print(f"Error handling key press: {e}")
# listener = keyboard.Listener(on_press=on_press)
# listener.start()
# return listener, events
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
from sshkeyboard import listen_keyboard
import threading
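# Unlike pynput, sshkeyboard reads key presses from the terminal, so it works over
# SSH and without a display server; listen_keyboard() blocks, hence the thread below.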
def on_press(key):
try:
if key == keyboard.Key.right:
if key == "right":
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
elif key == "left":
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
elif key == "q":
print("Q key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
listener.start()
return listener, events
return listener,events
def warmup_record(
robot,
@@ -256,7 +283,8 @@ def control_loop(
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
if display_cameras and not is_headless():
if display_cameras:
# if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@@ -288,17 +316,25 @@ def reset_environment(robot, events, reset_time_s, fps):
teleoperate=True,
)
# def stop_recording(robot, listener, display_cameras):
# robot.disconnect()
# if not is_headless():
# if listener is not None:
# listener.stop()
# if display_cameras:
# cv2.destroyAllWindows()
def stop_recording(robot, listener, display_cameras):
robot.disconnect()
if not is_headless():
if listener is not None:
listener.stop()
if display_cameras:
cv2.destroyAllWindows()
from sshkeyboard import stop_listening
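# stop_listening() makes the blocking listen_keyboard() call return, letting the listener thread exit.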
if listener is not None:
stop_listening()
if display_cameras:
cv2.destroyAllWindows()
def sanity_check_dataset_name(repo_id, policy_cfg):
_, dataset_name = repo_id.split("/")

View File

@@ -479,6 +479,45 @@ class So100RobotConfig(ManipulatorRobotConfig):
mock: bool = False
@RobotConfig.register_subclass("roarm_m3")
@dataclass
class RoarmRobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, str] = field(
default_factory=lambda: {
"main": "/dev/ttyUSB0",
}
)
# Follower arms configuration: left and right ports
follower_arms: dict[str, str] = field(
default_factory=lambda: {
"main": "/dev/ttyUSB1",
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("stretch")
@dataclass
@@ -515,7 +554,6 @@ class StretchRobotConfig(RobotConfig):
mock: bool = False
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiRobotConfig(RobotConfig):

View File

@@ -392,19 +392,21 @@ class MobileManipulator:
for name in self.leader_arms:
pos = self.leader_arms[name].read("Present_Position")
pos_tensor = torch.from_numpy(pos).float()
# Instead of pos_tensor.item(), use tolist() to convert the entire tensor to a list
arm_positions.extend(pos_tensor.tolist())
y_cmd = 0.0 # m/s forward/backward
x_cmd = 0.0 # m/s lateral
# (The rest of your code for generating wheel commands remains unchanged)
x_cmd = 0.0 # m/s forward/backward
y_cmd = 0.0 # m/s lateral
theta_cmd = 0.0 # deg/s rotation
if self.pressed_keys["forward"]:
y_cmd += xy_speed
if self.pressed_keys["backward"]:
y_cmd -= xy_speed
if self.pressed_keys["left"]:
x_cmd += xy_speed
if self.pressed_keys["right"]:
if self.pressed_keys["backward"]:
x_cmd -= xy_speed
if self.pressed_keys["left"]:
y_cmd += xy_speed
if self.pressed_keys["right"]:
y_cmd -= xy_speed
if self.pressed_keys["rotate_left"]:
theta_cmd += theta_speed
if self.pressed_keys["rotate_right"]:
@@ -582,8 +584,8 @@ class MobileManipulator:
# Create the body velocity vector [x, y, theta_rad].
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
# Define the wheel mounting angles (defined from y axis cw)
angles = np.radians(np.array([300, 180, 60]))
# Define the wheel mounting angles with a -90° offset.
angles = np.radians(np.array([240, 120, 0]) - 90)
# Build the kinematic matrix: each row maps body velocities to a wheel's linear speed.
# The third column (base_radius) accounts for the effect of rotation.
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
@@ -639,8 +641,8 @@ class MobileManipulator:
# Compute each wheel's linear speed (m/s) from its angular speed.
wheel_linear_speeds = wheel_radps * wheel_radius
# Define the wheel mounting angles (defined from y axis cw)
angles = np.radians(np.array([300, 180, 60]))
# Define the wheel mounting angles with a -90° offset.
angles = np.radians(np.array([240, 120, 0]) - 90)
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.

View File

@@ -0,0 +1,306 @@
"""Contains logic to instantiate a robot, read information from its motors and cameras,
and send orders to its motors.
"""
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
# calibration procedure, to make it easy for people to add their own robot.
import json
import logging
import time
import warnings
from pathlib import Path
import numpy as np
import torch
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from roarm_sdk.roarm import roarm
def ensure_safe_goal_position(
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
):
# Cap relative action target magnitude for safety.
diff = goal_pos - present_pos
max_relative_target = torch.tensor(max_relative_target)
safe_diff = torch.minimum(diff, max_relative_target)
safe_diff = torch.maximum(safe_diff, -max_relative_target)
safe_goal_pos = present_pos + safe_diff
if not torch.allclose(goal_pos, safe_goal_pos):
logging.warning(
"Relative goal position magnitude had to be clamped to be safe.\n"
f" requested relative goal position target: {diff}\n"
f" clamped relative goal position target: {safe_diff}"
)
return safe_goal_pos
def make_roarm_from_configs(configs: dict[str, str]) -> dict[str, roarm]:
roarms = {}
for key, port in configs.items():
roarms[key] = roarm(roarm_type="roarm_m3", port=port, baudrate=115200)
return roarms
class RoarmRobot:
def __init__(
self,
config: RoarmRobotConfig,
):
self.config = config
self.robot_type = self.config.type
self.leader_arms = make_roarm_from_configs(self.config.leader_arms)
self.follower_arms = make_roarm_from_configs(self.config.follower_arms)
self.cameras = make_cameras_from_configs(self.config.cameras)
self.is_connected = False
self.logs = {}
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
return {
"action": {
"dtype": "float32",
"shape": (6,),
"names": "roam_m3",
},
"observation.state": {
"dtype": "float32",
"shape": (6,),
"names": "roam_m3",
},
}
@property
def features(self):
return {**self.motor_features, **self.camera_features}
@property
def has_camera(self):
return len(self.cameras) > 0
@property
def num_cameras(self):
return len(self.cameras)
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
"RoarmRobot is already connected. Do not run `robot.connect()` twice."
)
if not self.leader_arms and not self.follower_arms and not self.cameras:
raise ValueError(
"RoarmRobot doesn't have any device to connect. See example of usage in docstring of the class."
)
for name in self.follower_arms:
self.follower_arms[name].joints_angle_get()
for name in self.leader_arms:
self.leader_arms[name].joints_angle_get()
# Connect the cameras
for name in self.cameras:
self.cameras[name].connect()
self.is_connected = True
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"RoarmRobot is not connected. You need to run `robot.connect()`."
)
# Prepare to assign the position of the leader to the follower
leader_pos = {}
for name in self.leader_arms:
before_lread_t = time.perf_counter()
leader_pos[name] = np.array(self.leader_arms[name].joints_angle_get(), dtype=np.float32)
leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
# Send goal position to the follower
follower_goal_pos = {}
for name in self.follower_arms:
before_fwrite_t = time.perf_counter()
goal_pos = leader_pos[name]
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
# Used when record_data=True
follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.int32)
goal_pos = goal_pos.tolist() if isinstance(goal_pos, (np.ndarray, torch.Tensor)) else goal_pos
self.follower_arms[name].joints_angle_ctrl(angles=goal_pos, speed=0, acc=0)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
# Early exit when recording data is not requested
if not record_data:
return
# TODO(rcadene): Add velocity and other info
# Read follower position
follower_pos = {}
for name in self.follower_arms:
before_fread_t = time.perf_counter()
follower_pos[name] = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
# Create state by concatenating follower current position
state = []
for name in self.follower_arms:
if name in follower_pos:
state.append(follower_pos[name])
state = torch.cat(state)
# Create action by concatenating follower goal position
action = []
for name in self.follower_arms:
if name in follower_goal_pos:
action.append(follower_goal_pos[name])
action = torch.cat(action)
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict, action_dict
def capture_observation(self):
"""The returned observations do not have a batch dimension."""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"RoarmRobot is not connected. You need to run `robot.connect()`."
)
# Read follower position
follower_pos = {}
for name in self.follower_arms:
before_fread_t = time.perf_counter()
follower_pos[name] = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
# Create state by concatenating follower current position
state = []
for name in self.follower_arms:
if name in follower_pos:
state.append(follower_pos[name])
state = torch.cat(state)
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries and format to pytorch
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor:
"""Command the follower arms to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
`max_relative_target`. In this case, the action sent differs from original action.
Thus, this function always returns the action actually sent.
Args:
action: tensor containing the concatenated goal positions for the follower arms.
"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"RoarmRobot is not connected. You need to run `robot.connect()`."
)
from_idx = 0
to_idx = 6
action_sent = []
for name in self.follower_arms:
# Get goal position of each follower arm by splitting the action vector
goal_pos = action[from_idx:to_idx]
from_idx = to_idx
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
# Save tensor to concat and return
action_sent.append(goal_pos)
# Send goal position to each follower
goal_pos = goal_pos.numpy().astype(np.int32)
goal_pos = goal_pos.tolist() if isinstance(goal_pos, (np.ndarray, torch.Tensor)) else goal_pos
self.follower_arms[name].joints_angle_ctrl(angles=goal_pos, speed=0, acc=0)
return torch.cat(action_sent)
def print_logs(self):
pass
# TODO(aliberts): move robot-specific logs logic here
def disconnect(self):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"RoarmRobot is not connected. You need to run `robot.connect()` before disconnecting."
)
for name in self.follower_arms:
self.follower_arms[name].disconnect()
for name in self.leader_arms:
self.leader_arms[name].disconnect()
for name in self.cameras:
self.cameras[name].disconnect()
self.is_connected = False
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()
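For context, here is a minimal usage sketch of this class, assuming a leader arm on `/dev/ttyUSB0` and a follower on `/dev/ttyUSB1` as in the default `RoarmRobotConfig`:
```python
from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
from lerobot.common.robot_devices.robots.roarm_m3 import RoarmRobot

config = RoarmRobotConfig(cameras={})  # skip cameras for a quick teleop check
robot = RoarmRobot(config)
robot.connect()
for _ in range(100):  # mirror the leader arm onto the follower for a few steps
    robot.teleop_step()
robot.disconnect()
```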

View File

@@ -9,6 +9,7 @@ from lerobot.common.robot_devices.robots.configs import (
MossRobotConfig,
RobotConfig,
So100RobotConfig,
RoarmRobotConfig,
StretchRobotConfig,
)
@@ -44,6 +45,8 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
return MossRobotConfig(**kwargs)
elif robot_type == "so100":
return So100RobotConfig(**kwargs)
elif robot_type == "roarm_m3":
return RoarmRobotConfig(**kwargs)
elif robot_type == "stretch":
return StretchRobotConfig(**kwargs)
elif robot_type == "lekiwi":
@@ -61,12 +64,15 @@ def make_robot_from_config(config: RobotConfig):
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
return MobileManipulator(config)
elif isinstance(config, RoarmRobotConfig):
from lerobot.common.robot_devices.robots.roarm_m3 import RoarmRobot
return RoarmRobot(config)
else:
from lerobot.common.robot_devices.robots.stretch import StretchRobot
return StretchRobot(config)
def make_robot(robot_type: str, **kwargs) -> Robot:
config = make_robot_config(robot_type, **kwargs)
return make_robot_from_config(config)

View File

@@ -17,7 +17,6 @@ import logging
import os
import os.path as osp
import platform
import subprocess
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
@@ -166,31 +165,23 @@ def capture_timestamp_utc():
def say(text, blocking=False):
system = platform.system()
if system == "Darwin":
cmd = ["say", text]
elif system == "Linux":
cmd = ["spd-say", text]
# Check if mac, linux, or windows.
if platform.system() == "Darwin":
cmd = f'say "{text}"'
if not blocking:
cmd += " &"
elif platform.system() == "Linux":
cmd = f'spd-say "{text}"'
if blocking:
cmd.append("--wait")
cmd += " --wait"
elif platform.system() == "Windows":
# TODO(rcadene): Make blocking option work for Windows
cmd = (
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
)
elif system == "Windows":
cmd = [
"PowerShell",
"-Command",
"Add-Type -AssemblyName System.Speech; "
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
]
else:
raise RuntimeError("Unsupported operating system for text-to-speech.")
if blocking:
subprocess.run(cmd, check=True)
else:
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
os.system(cmd)
def log_say(text, play_sounds, blocking=False):

View File

@@ -259,6 +259,7 @@ def record(
if not robot.is_connected:
robot.connect()
#listener, events = init_keyboard_listener()
listener, events = init_keyboard_listener()
# Execute a few seconds without recording to:
@@ -314,7 +315,9 @@
if events["stop_recording"]:
break
log_say("Stop recording", cfg.play_sounds, blocking=True)
# log_say("Stop recording", cfg.play_sounds, blocking=True)
log_say("Stop recording", cfg.play_sounds)
stop_recording(robot, listener, cfg.display_cameras)
if cfg.push_to_hub:

View File

@@ -454,7 +454,7 @@ def _compile_episode_data(
@parser.wrap()
def eval_main(cfg: EvalPipelineConfig):
def eval(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
# Check device is available
@@ -499,4 +499,4 @@ def eval_main(cfg: EvalPipelineConfig):
if __name__ == "__main__":
init_logging()
eval_main()
eval()

View File

@@ -158,7 +158,7 @@ def run_server(
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
@@ -194,7 +194,7 @@ def run_server(
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
)
response.raise_for_status()
# Split into lines and parse each line as JSON
@@ -218,7 +218,6 @@ def run_server(
videos_info=videos_info,
episode_data_csv_str=episode_data_csv_str,
columns=columns,
ignored_columns=ignored_columns,
)
app.run(host=host, port=port)
@@ -237,14 +236,6 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
selected_columns.remove("timestamp")
ignored_columns = []
for column_name in selected_columns:
shape = dataset.features[column_name]["shape"]
shape_dim = len(shape)
if shape_dim > 1:
selected_columns.remove(column_name)
ignored_columns.append(column_name)
# init header of csv with state and action names
header = ["timestamp"]
@@ -300,7 +291,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
return csv_string, columns, ignored_columns
return csv_string, columns
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
@@ -327,9 +318,7 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
)
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id

View File

@@ -42,22 +42,22 @@
<ul>
<template x-for="episode in paginatedEpisodes" :key="episode">
<li class="font-mono text-sm mt-0.5">
<a :href="'episode_' + episode"
<a :href="'episode_' + episode"
:class="{'underline': true, 'font-bold -ml-1': episode == {{ episode_id }}}"
x-text="'Episode ' + episode"></a>
</li>
</template>
</ul>
<div class="flex items-center mt-3 text-xs" x-show="totalPages > 1">
<button @click="prevPage()"
<button @click="prevPage()"
class="px-2 py-1 bg-slate-800 rounded mr-2"
:class="{'opacity-50 cursor-not-allowed': page === 1}"
:disabled="page === 1">
&laquo; Prev
</button>
<span class="font-mono mr-2" x-text="` ${page} / ${totalPages}`"></span>
<button @click="nextPage()"
<button @click="nextPage()"
class="px-2 py-1 bg-slate-800 rounded"
:class="{'opacity-50 cursor-not-allowed': page === totalPages}"
:disabled="page === totalPages">
@@ -65,10 +65,10 @@
</button>
</div>
</div>
<!-- episodes menu for small screens -->
<div class="flex overflow-x-auto md:hidden" x-data="episodePagination">
<button @click="prevPage()"
<button @click="prevPage()"
class="px-2 bg-slate-800 rounded mr-2"
:class="{'opacity-50 cursor-not-allowed': page === 1}"
:disabled="page === 1">&laquo;</button>
@@ -83,7 +83,7 @@
</p>
</template>
</div>
<button @click="nextPage()"
<button @click="nextPage()"
class="px-2 bg-slate-800 rounded ml-2"
:class="{'opacity-50 cursor-not-allowed': page === totalPages}"
:disabled="page === totalPages">&raquo; </button>
@@ -224,58 +224,49 @@
</p>
</div>
<div>
<table class="text-sm border-collapse border border-slate-700" x-show="currentFrameData">
<thead>
<tr>
<th></th>
<template x-for="(_, colIndex) in Array.from({length: columns.length}, (_, index) => index)">
<th class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2">
<input type="checkbox" :checked="isColumnChecked(colIndex)"
@change="toggleColumn(colIndex)">
<p x-text="`${columns[colIndex].key}`"></p>
</div>
</th>
</template>
</tr>
</thead>
<tbody>
<template x-for="(row, rowIndex) in rows">
<tr class="odd:bg-gray-800 even:bg-gray-900">
<td class="border border-slate-700">
<div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
<input type="checkbox" :checked="isRowChecked(rowIndex)"
@change="toggleRow(rowIndex)">
<table class="text-sm border-collapse border border-slate-700" x-show="currentFrameData">
<thead>
<tr>
<th></th>
<template x-for="(_, colIndex) in Array.from({length: columns.length}, (_, index) => index)">
<th class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2">
<input type="checkbox" :checked="isColumnChecked(colIndex)"
@change="toggleColumn(colIndex)">
<p x-text="`${columns[colIndex].key}`"></p>
</div>
</th>
</template>
</tr>
</thead>
<tbody>
<template x-for="(row, rowIndex) in rows">
<tr class="odd:bg-gray-800 even:bg-gray-900">
<td class="border border-slate-700">
<div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
<input type="checkbox" :checked="isRowChecked(rowIndex)"
@change="toggleRow(rowIndex)">
</div>
</td>
<template x-for="(cell, colIndex) in row">
<td x-show="cell" class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2" :class="{ 'hidden': cell.isNull }">
<div class="flex gap-x-2">
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
<span x-text="`${!cell.isNull ? cell.label : null}`"></span>
</div>
<span class="w-14 text-right" x-text="`${!cell.isNull ? (typeof cell.value === 'number' ? cell.value.toFixed(2) : cell.value) : null}`"
:style="`color: ${cell.color}`"></span>
</div>
</td>
<template x-for="(cell, colIndex) in row">
<td x-show="cell" class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2" :class="{ 'hidden': cell.isNull }">
<div class="flex gap-x-2">
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
<span x-text="`${!cell.isNull ? cell.label : null}`"></span>
</div>
<span class="w-14 text-right" x-text="`${!cell.isNull ? (typeof cell.value === 'number' ? cell.value.toFixed(2) : cell.value) : null}`"
:style="`color: ${cell.color}`"></span>
</div>
</td>
</template>
</tr>
</template>
</tbody>
</table>
</template>
</tr>
</template>
</tbody>
</table>
<div id="labels" class="hidden">
</div>
{% if ignored_columns|length > 0 %}
<div class="m-2 text-orange-700 max-w-96">
Columns {{ ignored_columns }} are NOT shown since the visualizer currently does not support 2D or 3D data.
</div>
{% endif %}
<div id="labels" class="hidden">
</div>
</div>
</div>
@@ -485,7 +476,7 @@
episodes: {{ episodes }},
pageSize: 100,
page: 1,
init() {
// Find which page contains the current episode_id
const currentEpisodeId = {{ episode_id }};
@@ -494,23 +485,23 @@
this.page = Math.floor(episodeIndex / this.pageSize) + 1;
}
},
get totalPages() {
return Math.ceil(this.episodes.length / this.pageSize);
},
get paginatedEpisodes() {
const start = (this.page - 1) * this.pageSize;
const end = start + this.pageSize;
return this.episodes.slice(start, end);
},
nextPage() {
if (this.page < this.totalPages) {
this.page++;
}
},
prevPage() {
if (this.page > 1) {
this.page--;
@@ -524,7 +515,7 @@
window.addEventListener('keydown', (e) => {
// Use the space bar to play and pause, instead of default action (e.g. scrolling)
const { keyCode, key } = e;
if (keyCode === 32 || key === ' ') {
e.preventDefault();
const btnPause = document.querySelector('[x-ref="btnPause"]');