From 496eb6298e6174fd56d82d59973a782c1c758adf Mon Sep 17 00:00:00 2001
From: DUDULRX <162013353+DUDULRX@users.noreply.github.com>
Date: Wed, 5 Mar 2025 20:00:42 +0800
Subject: [PATCH] Add support for the roarm_m3 robotic arm and adapt the
 keyboard listening method to work in headless (screenless) mode

---
 examples/12_use_roarm_m3.md                   | 140 ++++++++
 lerobot/common/datasets/lerobot_dataset.py    |  12 +-
 lerobot/common/datasets/utils.py              |  14 -
 .../datasets/v2/convert_dataset_v1_to_v2.py   |   3 +
 .../v21/convert_dataset_v20_to_v21.py         |   2 +-
 lerobot/common/policies/pi0/modeling_pi0.py   |   2 +-
 .../common/robot_devices/cameras/opencv.py    |  17 +
 lerobot/common/robot_devices/cameras/utils.py |   2 +-
 lerobot/common/robot_devices/control_utils.py |  92 ++++--
 .../common/robot_devices/robots/configs.py    |  40 ++-
 .../robots/mobile_manipulator.py              |  24 +-
 .../common/robot_devices/robots/roarm_m3.py   | 306 ++++++++++++++++++
 lerobot/common/robot_devices/robots/utils.py  |   8 +-
 lerobot/common/utils/utils.py                 |  39 +--
 lerobot/scripts/control_robot.py              |   5 +-
 lerobot/scripts/eval.py                       |   4 +-
 lerobot/scripts/visualize_dataset_html.py     |  19 +-
 .../templates/visualize_dataset_template.html | 113 +++----
 18 files changed, 671 insertions(+), 171 deletions(-)
 create mode 100644 examples/12_use_roarm_m3.md
 create mode 100644 lerobot/common/robot_devices/robots/roarm_m3.py

diff --git a/examples/12_use_roarm_m3.md b/examples/12_use_roarm_m3.md
new file mode 100644
index 00000000..eabeaf8a
--- /dev/null
+++ b/examples/12_use_roarm_m3.md
@@ -0,0 +1,140 @@
+# Using the [roarm_m3](https://github.com/waveshareteam/roarm_m3) with LeRobot
+
+## Table of Contents
+
+  - [A. Teleoperate](#a-teleoperate)
+  - [B. Record a dataset](#b-record-a-dataset)
+  - [C. Visualize a dataset](#c-visualize-a-dataset)
+  - [D. Replay an episode](#d-replay-an-episode)
+  - [E. Train a policy](#e-train-a-policy)
+  - [F. Evaluate your policy](#f-evaluate-your-policy)
+  - [G. More Information](#g-more-information)
+
+## A. Teleoperate
+
+**Simple teleop**
+#### a. Teleop without displaying cameras
+You can teleoperate your robot without connecting to or displaying the cameras:
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --robot.cameras='{}' \
+  --control.type=teleoperate
+```
+
+#### b. Teleop while displaying cameras
+You can display the cameras while teleoperating by running the command below. This is useful for preparing your setup before recording your first dataset.
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --control.type=teleoperate
+```
+
+## B. Record a dataset
+
+Once you're familiar with teleoperation, you can record your first dataset with the roarm_m3.
+
+If you want to use the Hugging Face hub features for uploading your dataset and you haven't done so previously, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
+```bash
+huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
+```
+
+Store your Hugging Face repository name in a variable to run these commands:
+```bash
+HF_USER=$(huggingface-cli whoami | head -n 1)
+echo $HF_USER
+```
+
+Record 2 episodes and upload your dataset to the hub:
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --control.type=record \
+  --control.fps=30 \
+  --control.single_task="Grasp a block and put it in the bin." \
+  --control.repo_id=${HF_USER}/roarm_m3_test \
+  --control.tags='["roarm_m3","tutorial"]' \
+  --control.warmup_time_s=5 \
+  --control.episode_time_s=30 \
+  --control.reset_time_s=30 \
+  --control.num_episodes=2 \
+  --control.push_to_hub=true
+```
+
+Note: You can resume recording by adding `--control.resume=true`.
+
+## C. Visualize a dataset
+
+If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy-pasting your repo id, given by:
+```bash
+echo ${HF_USER}/roarm_m3_test
+```
+
+If you didn't push the dataset to the hub (`--control.push_to_hub=false`), you can still visualize it locally; replace `ip` with your machine's address and open `http://ip:9090` in a browser to reach the visualization tool:
+```bash
+python lerobot/scripts/visualize_dataset_html.py \
+  --repo-id ${HF_USER}/roarm_m3_test \
+  --host ip \
+  --local-files-only 1
+```
+
+## D. Replay an episode
+
+Now try to replay the nth episode on your robot (episodes are 0-indexed, so the nth episode has index n-1):
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --control.type=replay \
+  --control.fps=30 \
+  --control.repo_id=${HF_USER}/roarm_m3_test \
+  --control.episode=n-1
+```
+
+## E. Train a policy
+
+To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
+```bash
+python lerobot/scripts/train.py \
+  --dataset.repo_id=${HF_USER}/roarm_m3_test \
+  --policy.type=act \
+  --output_dir=outputs/train/act_roarm_m3_test \
+  --job_name=act_roarm_m3_test \
+  --device=cuda \
+  --wandb.enable=true
+```
+
+Let's explain the arguments:
+1. We provided the dataset with `--dataset.repo_id=${HF_USER}/roarm_m3_test`.
+2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy automatically adapts to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) as saved in your dataset.
+3. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional, but if you use it, make sure you are logged in by running `wandb login`.
+
+Training should take several hours. You will find checkpoints in `outputs/train/act_roarm_m3_test/checkpoints`.
+
+## F. Evaluate your policy
+
+You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --control.type=record \
+  --control.fps=30 \
+  --control.single_task="Grasp a block and put it in the bin." \
+  --control.repo_id=${HF_USER}/eval_act_roarm_m3_test \
+  --control.tags='["tutorial"]' \
+  --control.warmup_time_s=5 \
+  --control.episode_time_s=30 \
+  --control.reset_time_s=30 \
+  --control.num_episodes=10 \
+  --control.push_to_hub=true \
+  --control.policy.path=outputs/train/act_roarm_m3_test/checkpoints/last/pretrained_model
+```
+
+As you can see, it's almost the same command as the one previously used to record your training dataset. Two things changed:
+1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/eval_act_roarm_m3_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_roarm_m3_test`).
+2. The dataset name begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_roarm_m3_test`).
+
+## G. More Information
+
+See this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth guide on controlling real robots with LeRobot.
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 5414c76d..505c33b2 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import contextlib
 import logging
 import shutil
 from pathlib import Path
@@ -28,7 +27,6 @@
 import torch.utils
 from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.constants import REPOCARD_NAME
-from huggingface_hub.errors import RevisionNotFoundError
 
 from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
@@ -314,7 +312,7 @@ class LeRobotDatasetMetadata:
         obj.repo_id = repo_id
         obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
 
-        obj.root.mkdir(parents=True, exist_ok=False)
+        obj.root.mkdir(parents=True, exist_ok=True)
 
         if robot is not None:
             features = get_features_from_robot(robot, use_videos)
@@ -519,7 +517,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         branch: str | None = None,
         tags: list | None = None,
         license: str | None = "apache-2.0",
-        tag_version: bool = True,
         push_videos: bool = True,
         private: bool = False,
         allow_patterns: list[str] | str | None = None,
@@ -565,11 +562,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
         card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
 
-        if tag_version:
-            with contextlib.suppress(RevisionNotFoundError):
-                hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
-            hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
-
     def pull_from_repo(
         self,
         allow_patterns: list[str] | str | None = None,
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 7e297b35..89adb163 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -31,7 +31,6 @@
 import packaging.version
 import torch
 from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, DatasetCardData, HfApi
-from huggingface_hub.errors import RevisionNotFoundError
 from PIL import Image as PILImage
 from torchvision import transforms
@@ -326,19 +325,6 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
     )
     hub_versions = get_repo_versions(repo_id)
-    if not hub_versions:
-        raise RevisionNotFoundError(
-            f"""Your dataset must be tagged with a codebase version.
-            Assuming _version_ is the codebase_version value in the info.json, you can run this:
-            ```python
-            from huggingface_hub import HfApi
-
-            hub_api = HfApi()
-            hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
-            ```
-            """
-        )
-
     if target_version in hub_versions:
         return f"v{target_version}"
 
diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
index acf0282f..fe658ce4 100644
--- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
+++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
@@ -166,6 +166,9 @@ def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
             for arm in robot_cfg.leader_arms
             for motor in robot_cfg.leader_arms[arm].motors
         ]
+    elif robot_cfg.type == "roarm_m3":
+        state_names = ["roarm_m3"] * 6
+        action_names = ["roarm_m3"] * 6
     # elif robot_cfg["robot_type"] == "stretch3": TODO
     else:
         raise NotImplementedError(
diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
index 163a6003..20bda75b 100644
--- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
+++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
@@ -57,7 +57,7 @@ def convert_dataset(
     dataset.meta.info["codebase_version"] = CODEBASE_VERSION
     write_info(dataset.meta.info, dataset.root)
 
-    dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
+    dataset.push_to_hub(branch=branch, allow_patterns="meta/")
 
     # delete old stats.json file
     if (dataset.root / STATS_PATH).is_file:
diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py
index bc53bf85..c8b12caf 100644
--- a/lerobot/common/policies/pi0/modeling_pi0.py
+++ b/lerobot/common/policies/pi0/modeling_pi0.py
@@ -313,7 +313,7 @@ class PI0Policy(PreTrainedPolicy):
         state = self.prepare_state(batch)
         lang_tokens, lang_masks = self.prepare_language(batch)
         actions = self.prepare_action(batch)
-        actions_is_pad = batch.get("actions_is_pad")
+        actions_is_pad = batch.get("action_is_pad")
 
         loss_dict = {}
         losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py
index 93c791fa..506f6075 100644
--- a/lerobot/common/robot_devices/cameras/opencv.py
+++ b/lerobot/common/robot_devices/cameras/opencv.py
@@ -30,6 +30,20 @@
 # treat the same cameras as new devices. Thus we select a higher bound to search indices.
 MAX_OPENCV_INDEX = 60
 
+# Global toggle: when True, every frame read by OpenCVCamera is undistorted below.
+undistort = True
+
+def undistort_image(image):
+    import cv2
+
+    # Camera intrinsics and distortion coefficients (calibration values for one
+    # specific camera; replace them with values calibrated for your own camera).
+    camera_matrix = np.array([
+        [289.11451, 0., 347.23664],
+        [0., 289.75319, 235.67429],
+        [0., 0., 1.],
+    ])
+    dist_coeffs = np.array([-0.208848, 0.028006, -0.000705, -0.000820, 0.000000])
+
+    undistorted_image = cv2.undistort(image, camera_matrix, dist_coeffs)
+
+    return undistorted_image
 
 def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
     cameras = []
@@ -368,6 +382,9 @@ class OpenCVCamera:
 
         color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
 
+        if undistort:
+            color_image = undistort_image(color_image)
+
         h, w, _ = color_image.shape
         if h != self.height or w != self.width:
             raise OSError(
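The `camera_matrix` and `dist_coeffs` hardcoded above are calibration values for one specific camera; any other camera needs its own. As a hedged sketch of where such numbers come from (not part of the patch — the 9x6 chessboard size and `calib/*.png` paths are placeholder assumptions), OpenCV's standard chessboard calibration looks like this:

```python
# Sketch: estimate camera_matrix / dist_coeffs from chessboard photos with OpenCV.
import glob

import cv2
import numpy as np

pattern = (9, 6)  # inner corners per row/column of the printed chessboard (assumed)
objp = np.zeros((pattern[0] * pattern[1], 3), np.float32)
objp[:, :2] = np.mgrid[0 : pattern[0], 0 : pattern[1]].T.reshape(-1, 2)

obj_points, img_points = [], []
for path in glob.glob("calib/*.png"):  # placeholder location of calibration shots
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    found, corners = cv2.findChessboardCorners(gray, pattern)
    if found:
        obj_points.append(objp)
        img_points.append(corners)

# imageSize is (width, height); gray.shape is (height, width), hence the reversal.
ret, camera_matrix, dist_coeffs, _, _ = cv2.calibrateCamera(
    obj_points, img_points, gray.shape[::-1], None, None
)
print(camera_matrix)
print(dist_coeffs.ravel())  # [k1, k2, p1, p2, k3], the format hardcoded above
```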
diff --git a/lerobot/common/robot_devices/cameras/utils.py b/lerobot/common/robot_devices/cameras/utils.py
index ef6d8266..88288ea3 100644
--- a/lerobot/common/robot_devices/cameras/utils.py
+++ b/lerobot/common/robot_devices/cameras/utils.py
@@ -31,7 +31,7 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
             cameras[key] = IntelRealSenseCamera(cfg)
 
         else:
-            raise ValueError(f"The motor type '{cfg.type}' is not valid.")
+            raise ValueError(f"The camera type '{cfg.type}' is not valid.")
 
     return cameras
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index d2361a64..c5870e1f 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -113,46 +113,73 @@ def predict_action(observation, policy, device, use_amp):
 
     return action
 
 def init_keyboard_listener():
-    # Allow to exit early while recording an episode or resetting the environment,
-    # by tapping the right arrow key '->'. This might require a sudo permission
-    # to allow your terminal to monitor keyboard events.
     events = {}
     events["exit_early"] = False
     events["rerecord_episode"] = False
     events["stop_recording"] = False
-
-    if is_headless():
-        logging.warning(
-            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
-        )
-        listener = None
-        return listener, events
-
-    # Only import pynput if not in a headless environment
-    from pynput import keyboard
-
+    # sshkeyboard reads stdin, so it also works in headless (e.g. SSH) sessions.
+    from sshkeyboard import listen_keyboard
+    import threading
+
     def on_press(key):
         try:
-            if key == keyboard.Key.right:
+            if key == "right":
                 print("Right arrow key pressed. Exiting loop...")
                 events["exit_early"] = True
-            elif key == keyboard.Key.left:
+            elif key == "left":
                 print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                 events["rerecord_episode"] = True
                 events["exit_early"] = True
-            elif key == keyboard.Key.esc:
-                print("Escape key pressed. Stopping data recording...")
+            elif key == "q":
+                print("Q key pressed. Stopping data recording...")
                 events["stop_recording"] = True
                 events["exit_early"] = True
         except Exception as e:
             print(f"Error handling key press: {e}")
 
-    listener = keyboard.Listener(on_press=on_press)
+    # listen_keyboard() blocks, so run it in a background thread.
+    listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
     listener.start()
 
     return listener, events
 
 def warmup_record(
     robot,
@@ -256,7 +283,8 @@ def control_loop(
             frame = {**observation, **action, "task": single_task}
             dataset.add_frame(frame)
 
-        if display_cameras and not is_headless():
+        if display_cameras:
             image_keys = [key for key in observation if "image" in key]
             for key in image_keys:
                 cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@@ -288,17 +316,25 @@ def reset_environment(robot, events, reset_time_s, fps):
         teleoperate=True,
     )
 
 def stop_recording(robot, listener, display_cameras):
     robot.disconnect()
 
-    if not is_headless():
-        if listener is not None:
-            listener.stop()
+    from sshkeyboard import stop_listening
 
-        if display_cameras:
-            cv2.destroyAllWindows()
+    # Ask the sshkeyboard loop to exit; the listener thread then finishes.
+    if listener is not None:
+        stop_listening()
+
+    if display_cameras:
+        cv2.destroyAllWindows()
 
 def sanity_check_dataset_name(repo_id, policy_cfg):
     _, dataset_name = repo_id.split("/")
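Since the listener now relies on sshkeyboard instead of pynput, here is a minimal, self-contained sketch of the pattern (a sketch only — it assumes nothing beyond the `sshkeyboard` package being installed; the key names and `events` dict mirror the code above):

```python
# sshkeyboard reads key presses from stdin, so unlike pynput it needs no display
# server and keeps working over SSH or in an otherwise headless session.
import threading
import time

from sshkeyboard import listen_keyboard, stop_listening

events = {"exit_early": False}

def on_press(key):
    # Arrow keys arrive as plain strings: "up", "down", "left", "right".
    if key == "right":
        events["exit_early"] = True
        stop_listening()  # ends listen_keyboard(), letting the thread finish

listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
listener.start()

while not events["exit_early"]:
    time.sleep(0.1)  # a control loop would do its work here and poll the flag

listener.join()
```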
diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py
index 88cb4e6f..ac38c7eb 100644
--- a/lerobot/common/robot_devices/robots/configs.py
+++ b/lerobot/common/robot_devices/robots/configs.py
@@ -479,6 +479,45 @@ class So100RobotConfig(ManipulatorRobotConfig):
     mock: bool = False
 
 
+@RobotConfig.register_subclass("roarm_m3")
+@dataclass
+class RoarmRobotConfig(RobotConfig):
+    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
+    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length
+    # as the number of motors in your follower arms.
+    max_relative_target: int | None = None
+
+    # Leader arm configuration: serial port of the leader arm.
+    leader_arms: dict[str, str] = field(
+        default_factory=lambda: {
+            "main": "/dev/ttyUSB0",
+        }
+    )
+
+    # Follower arm configuration: serial port of the follower arm.
+    follower_arms: dict[str, str] = field(
+        default_factory=lambda: {
+            "main": "/dev/ttyUSB1",
+        }
+    )
+
+    cameras: dict[str, CameraConfig] = field(
+        default_factory=lambda: {
+            "laptop": OpenCVCameraConfig(
+                camera_index=0,
+                fps=30,
+                width=640,
+                height=480,
+            ),
+            "phone": OpenCVCameraConfig(
+                camera_index=1,
+                fps=30,
+                width=640,
+                height=480,
+            ),
+        }
+    )
+
+    mock: bool = False
+
 
 @RobotConfig.register_subclass("stretch")
 @dataclass
@@ -515,7 +554,6 @@ class StretchRobotConfig(RobotConfig):
     mock: bool = False
 
-
 @RobotConfig.register_subclass("lekiwi")
 @dataclass
 class LeKiwiRobotConfig(RobotConfig):
diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py
index c2cad227..b20c61f7 100644
--- a/lerobot/common/robot_devices/robots/mobile_manipulator.py
+++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py
@@ -392,19 +392,21 @@ class MobileManipulator:
         for name in self.leader_arms:
             pos = self.leader_arms[name].read("Present_Position")
             pos_tensor = torch.from_numpy(pos).float()
+            # Use tolist() rather than pos_tensor.item() to convert the whole tensor to a list.
             arm_positions.extend(pos_tensor.tolist())
 
-        y_cmd = 0.0  # m/s forward/backward
-        x_cmd = 0.0  # m/s lateral
+        x_cmd = 0.0  # m/s forward/backward
+        y_cmd = 0.0  # m/s lateral
         theta_cmd = 0.0  # deg/s rotation
         if self.pressed_keys["forward"]:
-            y_cmd += xy_speed
+            x_cmd += xy_speed
         if self.pressed_keys["backward"]:
-            y_cmd -= xy_speed
+            x_cmd -= xy_speed
         if self.pressed_keys["left"]:
-            x_cmd += xy_speed
+            y_cmd += xy_speed
         if self.pressed_keys["right"]:
-            x_cmd -= xy_speed
+            y_cmd -= xy_speed
         if self.pressed_keys["rotate_left"]:
             theta_cmd += theta_speed
         if self.pressed_keys["rotate_right"]:
@@ -582,8 +584,8 @@ class MobileManipulator:
         # Create the body velocity vector [x, y, theta_rad].
         velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
 
-        # Define the wheel mounting angles (defined from y axis cw)
-        angles = np.radians(np.array([300, 180, 60]))
+        # Define the wheel mounting angles with a -90° offset.
+        angles = np.radians(np.array([240, 120, 0]) - 90)
         # Build the kinematic matrix: each row maps body velocities to a wheel's linear speed.
         # The third column (base_radius) accounts for the effect of rotation.
         m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
@@ -639,8 +641,8 @@ class MobileManipulator:
         # Compute each wheel's linear speed (m/s) from its angular speed.
         wheel_linear_speeds = wheel_radps * wheel_radius
 
-        # Define the wheel mounting angles (defined from y axis cw)
-        angles = np.radians(np.array([300, 180, 60]))
+        # Define the wheel mounting angles with a -90° offset.
+        angles = np.radians(np.array([240, 120, 0]) - 90)
         m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
 
         # Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
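To make the new mounting angles concrete, here is a small self-contained check (not part of the patch; `base_radius` and the velocities are illustrative values): the kinematic matrix built from `[240, 120, 0] - 90` maps a body velocity to three wheel speeds, and its inverse recovers the body velocity, mirroring the two code paths above.

```python
import numpy as np

base_radius = 0.125  # meters; illustrative value

# Wheel mounting angles with the -90 degree offset, as in the diff above.
angles = np.radians(np.array([240, 120, 0]) - 90)
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])

body_velocity = np.array([0.2, 0.1, 0.5])  # [x m/s, y m/s, theta rad/s]
wheel_linear_speeds = m @ body_velocity     # forward: body -> wheel speeds

recovered = np.linalg.inv(m) @ wheel_linear_speeds  # inverse: wheels -> body
assert np.allclose(recovered, body_velocity)
```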
diff --git a/lerobot/common/robot_devices/robots/roarm_m3.py b/lerobot/common/robot_devices/robots/roarm_m3.py
new file mode 100644
index 00000000..24461194
--- /dev/null
+++ b/lerobot/common/robot_devices/robots/roarm_m3.py
@@ -0,0 +1,306 @@
+"""Contains logic to instantiate a robot, read information from its motors and cameras,
+and send orders to its motors.
+"""
+# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
+# calibration procedure, to make it easy for people to add their own robot.
+
+import logging
+import time
+
+import numpy as np
+import torch
+from roarm_sdk.roarm import roarm
+
+from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
+from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
+from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
+
+
+def ensure_safe_goal_position(
+    goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
+):
+    # Cap relative action target magnitude for safety.
+    diff = goal_pos - present_pos
+    max_relative_target = torch.tensor(max_relative_target)
+    safe_diff = torch.minimum(diff, max_relative_target)
+    safe_diff = torch.maximum(safe_diff, -max_relative_target)
+    safe_goal_pos = present_pos + safe_diff
+
+    if not torch.allclose(goal_pos, safe_goal_pos):
+        logging.warning(
+            "Relative goal position magnitude had to be clamped to be safe.\n"
+            f"  requested relative goal position target: {diff}\n"
+            f"    clamped relative goal position target: {safe_diff}"
+        )
+
+    return safe_goal_pos
+
+
+def make_roarm_from_configs(configs: dict[str, str]) -> dict[str, roarm]:
+    # Open one serial connection per configured arm, e.g. {"main": "/dev/ttyUSB0"}.
+    roarms = {}
+
+    for key, port in configs.items():
+        roarms[key] = roarm(roarm_type="roarm_m3", port=port, baudrate=115200)
+
+    return roarms
+
+
+class RoarmRobot:
+    def __init__(
+        self,
+        config: RoarmRobotConfig,
+    ):
+        self.config = config
+        self.robot_type = self.config.type
+        self.leader_arms = make_roarm_from_configs(self.config.leader_arms)
+        self.follower_arms = make_roarm_from_configs(self.config.follower_arms)
+        self.cameras = make_cameras_from_configs(self.config.cameras)
+        self.is_connected = False
+        self.logs = {}
+
+    @property
+    def camera_features(self) -> dict:
+        cam_ft = {}
+        for cam_key, cam in self.cameras.items():
+            key = f"observation.images.{cam_key}"
+            cam_ft[key] = {
+                "shape": (cam.height, cam.width, cam.channels),
+                "names": ["height", "width", "channels"],
+                "info": None,
+            }
+        return cam_ft
+
+    @property
+    def motor_features(self) -> dict:
+        return {
+            "action": {
+                "dtype": "float32",
+                "shape": (6,),
+                "names": "roarm_m3",
+            },
+            "observation.state": {
+                "dtype": "float32",
+                "shape": (6,),
+                "names": "roarm_m3",
+            },
+        }
+
+    @property
+    def features(self):
+        return {**self.motor_features, **self.camera_features}
+
+    @property
+    def has_camera(self):
+        return len(self.cameras) > 0
+
+    @property
+    def num_cameras(self):
+        return len(self.cameras)
+
+    def connect(self):
+        if self.is_connected:
+            raise RobotDeviceAlreadyConnectedError(
+                "RoarmRobot is already connected. Do not run `robot.connect()` twice."
+            )
+
+        if not self.leader_arms and not self.follower_arms and not self.cameras:
+            raise ValueError(
+                "RoarmRobot doesn't have any device to connect. See example of usage in docstring of the class."
+ ) + + for name in self.follower_arms: + self.follower_arms[name].joints_angle_get() + for name in self.leader_arms: + self.leader_arms[name].joints_angle_get() + + # Connect the cameras + for name in self.cameras: + self.cameras[name].connect() + + self.is_connected = True + + def teleop_step( + self, record_data=False + ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + if not self.is_connected: + raise RobotDeviceNotConnectedError( + "RoarmRobot is not connected. You need to run `robot.connect()`." + ) + + # Prepare to assign the position of the leader to the follower + leader_pos = {} + for name in self.leader_arms: + before_lread_t = time.perf_counter() + leader_pos[name] = np.array(self.leader_arms[name].joints_angle_get(), dtype=np.float32) + leader_pos[name] = torch.from_numpy(leader_pos[name]) + self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t + + # Send goal position to the follower + follower_goal_pos = {} + for name in self.follower_arms: + before_fwrite_t = time.perf_counter() + goal_pos = leader_pos[name] + + # Cap goal position when too far away from present position. + # Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32) + present_pos = torch.from_numpy(present_pos) + goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) + + # Used when record_data=True + follower_goal_pos[name] = goal_pos + + goal_pos = goal_pos.numpy().astype(np.int32) + goal_pos = goal_pos.tolist() if isinstance(goal_pos, (np.ndarray, torch.Tensor)) else goal_pos + self.follower_arms[name].joints_angle_ctrl(angles=goal_pos, speed=0, acc=0) + self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t + + # Early exit when recording data is not requested + if not record_data: + return + + # TODO(rcadene): Add velocity and other info + # Read follower position + follower_pos = {} + for name in self.follower_arms: + before_fread_t = time.perf_counter() + follower_pos[name] = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32) + follower_pos[name] = torch.from_numpy(follower_pos[name]) + self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t + + # Create state by concatenating follower current position + state = [] + for name in self.follower_arms: + if name in follower_pos: + state.append(follower_pos[name]) + state = torch.cat(state) + + # Create action by concatenating follower goal position + action = [] + for name in self.follower_arms: + if name in follower_goal_pos: + action.append(follower_goal_pos[name]) + action = torch.cat(action) + + # Capture images from cameras + images = {} + for name in self.cameras: + before_camread_t = time.perf_counter() + images[name] = self.cameras[name].async_read() + images[name] = torch.from_numpy(images[name]) + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + + # Populate output dictionaries + obs_dict, action_dict = {}, {} + obs_dict["observation.state"] = state + action_dict["action"] = action + for name in self.cameras: + obs_dict[f"observation.images.{name}"] = images[name] + + return obs_dict, action_dict + + def capture_observation(self): + """The returned observations do not have a batch dimension.""" + if not self.is_connected: + 
+            raise RobotDeviceNotConnectedError(
+                "RoarmRobot is not connected. You need to run `robot.connect()`."
+            )
+
+        # Read follower position
+        follower_pos = {}
+        for name in self.follower_arms:
+            before_fread_t = time.perf_counter()
+            follower_pos[name] = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
+            follower_pos[name] = torch.from_numpy(follower_pos[name])
+            self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
+
+        # Create state by concatenating follower current position
+        state = []
+        for name in self.follower_arms:
+            if name in follower_pos:
+                state.append(follower_pos[name])
+        state = torch.cat(state)
+
+        # Capture images from cameras
+        images = {}
+        for name in self.cameras:
+            before_camread_t = time.perf_counter()
+            images[name] = self.cameras[name].async_read()
+            images[name] = torch.from_numpy(images[name])
+            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
+            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
+
+        # Populate output dictionaries and format to pytorch
+        obs_dict = {}
+        obs_dict["observation.state"] = state
+        for name in self.cameras:
+            obs_dict[f"observation.images.{name}"] = images[name]
+        return obs_dict
+
+    def send_action(self, action: torch.Tensor) -> torch.Tensor:
+        """Command the follower arms to move to a target joint configuration.
+
+        The relative action magnitude may be clipped depending on the configuration parameter
+        `max_relative_target`. In this case, the action sent differs from the original action.
+        Thus, this function always returns the action actually sent.
+
+        Args:
+            action: tensor containing the concatenated goal positions for the follower arms.
+        """
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                "RoarmRobot is not connected. You need to run `robot.connect()`."
+            )
+
+        from_idx = 0
+        to_idx = 6
+        action_sent = []
+        for name in self.follower_arms:
+            # Get goal position of each follower arm by splitting the action vector
+            goal_pos = action[from_idx:to_idx]
+            from_idx = to_idx
+
+            # Cap goal position when too far away from present position.
+            # Slower fps expected due to reading from the follower.
+            if self.config.max_relative_target is not None:
+                present_pos = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
+                present_pos = torch.from_numpy(present_pos)
+                goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
+
+            # Save tensor to concat and return
+            action_sent.append(goal_pos)
+
+            # Send goal position to each follower
+            goal_pos = goal_pos.numpy().astype(np.int32).tolist()
+            self.follower_arms[name].joints_angle_ctrl(angles=goal_pos, speed=0, acc=0)
+
+        return torch.cat(action_sent)
+
+    def print_logs(self):
+        pass
+        # TODO(aliberts): move robot-specific logs logic here
+
+    def disconnect(self):
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                "RoarmRobot is not connected. You need to run `robot.connect()` before disconnecting."
+ ) + + for name in self.follower_arms: + self.follower_arms[name].disconnect() + + for name in self.leader_arms: + self.leader_arms[name].disconnect() + + for name in self.cameras: + self.cameras[name].disconnect() + + self.is_connected = False + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py index 47e2519b..d418ec34 100644 --- a/lerobot/common/robot_devices/robots/utils.py +++ b/lerobot/common/robot_devices/robots/utils.py @@ -9,6 +9,7 @@ from lerobot.common.robot_devices.robots.configs import ( MossRobotConfig, RobotConfig, So100RobotConfig, + RoarmRobotConfig, StretchRobotConfig, ) @@ -44,6 +45,8 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig: return MossRobotConfig(**kwargs) elif robot_type == "so100": return So100RobotConfig(**kwargs) + elif robot_type == "roarm_m3": + return RoarmRobotConfig(**kwargs) elif robot_type == "stretch": return StretchRobotConfig(**kwargs) elif robot_type == "lekiwi": @@ -61,12 +64,15 @@ def make_robot_from_config(config: RobotConfig): from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator return MobileManipulator(config) + elif isinstance(config, RoarmRobotConfig): + from lerobot.common.robot_devices.robots.roarm_m3 import RoarmRobot + + return RoarmRobot(config) else: from lerobot.common.robot_devices.robots.stretch import StretchRobot return StretchRobot(config) - def make_robot(robot_type: str, **kwargs) -> Robot: config = make_robot_config(robot_type, **kwargs) return make_robot_from_config(config) diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index cd26f04b..d0c12b30 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -17,7 +17,6 @@ import logging import os import os.path as osp import platform -import subprocess from copy import copy from datetime import datetime, timezone from pathlib import Path @@ -166,31 +165,23 @@ def capture_timestamp_utc(): def say(text, blocking=False): - system = platform.system() - - if system == "Darwin": - cmd = ["say", text] - - elif system == "Linux": - cmd = ["spd-say", text] + # Check if mac, linux, or windows. 
diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py
index 47e2519b..d418ec34 100644
--- a/lerobot/common/robot_devices/robots/utils.py
+++ b/lerobot/common/robot_devices/robots/utils.py
@@ -9,6 +9,7 @@
 from lerobot.common.robot_devices.robots.configs import (
     MossRobotConfig,
+    RoarmRobotConfig,
     RobotConfig,
     So100RobotConfig,
     StretchRobotConfig,
 )
@@ -44,6 +45,8 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
         return MossRobotConfig(**kwargs)
     elif robot_type == "so100":
         return So100RobotConfig(**kwargs)
+    elif robot_type == "roarm_m3":
+        return RoarmRobotConfig(**kwargs)
     elif robot_type == "stretch":
         return StretchRobotConfig(**kwargs)
     elif robot_type == "lekiwi":
@@ -61,12 +64,15 @@ def make_robot_from_config(config: RobotConfig):
         from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
 
         return MobileManipulator(config)
+    elif isinstance(config, RoarmRobotConfig):
+        from lerobot.common.robot_devices.robots.roarm_m3 import RoarmRobot
+
+        return RoarmRobot(config)
     else:
         from lerobot.common.robot_devices.robots.stretch import StretchRobot
 
         return StretchRobot(config)
 
-
 def make_robot(robot_type: str, **kwargs) -> Robot:
     config = make_robot_config(robot_type, **kwargs)
     return make_robot_from_config(config)
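With this factory wiring in place, the new robot can be exercised end to end. A hedged usage sketch (assuming the default `/dev/ttyUSB0`/`/dev/ttyUSB1` ports and the cameras from `RoarmRobotConfig`; not part of the patch):

```python
import torch

from lerobot.common.robot_devices.robots.utils import make_robot

robot = make_robot("roarm_m3")
robot.connect()

# Mirror the leader arm onto the follower and collect one frame of data.
obs, action = robot.teleop_step(record_data=True)
print(obs["observation.state"])  # six joint angles as a float32 tensor

# Or command the follower directly; with max_relative_target set, each
# step is clamped by ensure_safe_goal_position() before being sent.
robot.send_action(torch.zeros(6))

robot.disconnect()
```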
diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index cd26f04b..d0c12b30 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -17,7 +17,6 @@
 import logging
 import os
 import os.path as osp
 import platform
-import subprocess
 from copy import copy
 from datetime import datetime, timezone
 from pathlib import Path
@@ -166,31 +165,23 @@ def capture_timestamp_utc():
 
 def say(text, blocking=False):
-    system = platform.system()
-
-    if system == "Darwin":
-        cmd = ["say", text]
-
-    elif system == "Linux":
-        cmd = ["spd-say", text]
-        if blocking:
-            cmd.append("--wait")
-
-    elif system == "Windows":
-        cmd = [
-            "PowerShell",
-            "-Command",
-            "Add-Type -AssemblyName System.Speech; "
-            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
-        ]
-
-    else:
-        raise RuntimeError("Unsupported operating system for text-to-speech.")
-
-    if blocking:
-        subprocess.run(cmd, check=True)
-    else:
-        subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
+    # Check whether the host is macOS, Linux, or Windows.
+    if platform.system() == "Darwin":
+        cmd = f'say "{text}"'
+        if not blocking:
+            cmd += " &"
+    elif platform.system() == "Linux":
+        cmd = f'spd-say "{text}"'
+        if blocking:
+            cmd += " --wait"
+    elif platform.system() == "Windows":
+        # TODO(rcadene): Make blocking option work for Windows
+        cmd = (
+            'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
+            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
+        )
+
+    os.system(cmd)
 
 def log_say(text, play_sounds, blocking=False):
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index ab5d0e8a..32271b80 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -314,7 +315,9 @@ def record(
         if events["stop_recording"]:
             break
 
-    log_say("Stop recording", cfg.play_sounds, blocking=True)
+    log_say("Stop recording", cfg.play_sounds)
+
     stop_recording(robot, listener, cfg.display_cameras)
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 47225993..a4f79afc 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -454,7 +454,7 @@ def _compile_episode_data(
 
 @parser.wrap()
-def eval_main(cfg: EvalPipelineConfig):
+def eval(cfg: EvalPipelineConfig):
     logging.info(pformat(asdict(cfg)))
 
     # Check device is available
@@ -499,4 +499,4 @@
 
 if __name__ == "__main__":
     init_logging()
-    eval_main()
+    eval()
diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
index d5825aa6..a0da0869 100644
--- a/lerobot/scripts/visualize_dataset_html.py
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -158,7 +158,7 @@ def run_server(
         if major_version < 2:
             return "Make sure to convert your LeRobotDataset to v2 & above."
 
-        episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
+        episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
         dataset_info = {
             "repo_id": f"{dataset_namespace}/{dataset_name}",
             "num_samples": dataset.num_frames
@@ -194,7 +194,7 @@ def run_server(
         ]
 
         response = requests.get(
-            f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
+            f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
         )
         response.raise_for_status()
         # Split into lines and parse each line as JSON
@@ -218,7 +218,6 @@ def run_server(
             videos_info=videos_info,
             episode_data_csv_str=episode_data_csv_str,
             columns=columns,
-            ignored_columns=ignored_columns,
         )
 
     app.run(host=host, port=port)
@@ -237,14 +236,6 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
     selected_columns.remove("timestamp")
 
-    ignored_columns = []
-    for column_name in selected_columns:
-        shape = dataset.features[column_name]["shape"]
-        shape_dim = len(shape)
-        if shape_dim > 1:
-            selected_columns.remove(column_name)
-            ignored_columns.append(column_name)
-
     # init header of csv with state and action names
     header = ["timestamp"]
@@ -300,7 +291,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
     csv_writer.writerows(rows)
     csv_string = csv_buffer.getvalue()
 
-    return csv_string, columns, ignored_columns
+    return csv_string, columns
 
 
 def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
@@ -327,9 +318,7 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
 
 def get_dataset_info(repo_id: str) -> IterableNamespace:
-    response = requests.get(
-        f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
-    )
+    response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
     response.raise_for_status()  # Raises an HTTPError for bad responses
     dataset_info = response.json()
     dataset_info["repo_id"] = repo_id
diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html
index cf9d40f1..d81ce630 100644
--- a/lerobot/templates/visualize_dataset_template.html
+++ b/lerobot/templates/visualize_dataset_template.html

[The template hunks are garbled in this copy of the patch: most of the HTML markup was lost in extraction. The recoverable changes mirror the Python changes above: the `{% if ignored_columns|length > 0 %}` block — which rendered the warning "Columns {{ ignored_columns }} are NOT shown since the visualizer currently does not support 2D or 3D data." — is removed along with its surrounding markup, and the remaining hunks in the Alpine.js pagination component and the space-bar play/pause keyboard handler are whitespace-only.]