diff --git a/examples/12_use_roarm_m3.md b/examples/12_use_roarm_m3.md
new file mode 100644
index 00000000..44c2570a
--- /dev/null
+++ b/examples/12_use_roarm_m3.md
@@ -0,0 +1,144 @@
+# Using the [roarm_m3](https://github.com/waveshareteam/roarm_m3) with LeRobot
+
+## Table of Contents
+
+  - [A. Install LeRobot](#a-install-lerobot)
+  - [B. Teleoperate](#b-teleoperate)
+  - [C. Record a dataset](#c-record-a-dataset)
+  - [D. Visualize a dataset](#d-visualize-a-dataset)
+  - [E. Replay an episode](#e-replay-an-episode)
+  - [F. Train a policy](#f-train-a-policy)
+  - [G. Evaluate your policy](#g-evaluate-your-policy)
+  - [H. More Information](#h-more-information)
+
+## A. Install LeRobot
+
+Before running the following commands, make sure you have installed LeRobot by following the [installation instructions](https://github.com/lerobot/lerobot/blob/main/README.md).
+
+## B. Teleoperate
+
+**Simple teleop**
+#### a. Teleop without displaying cameras
+You can teleoperate your robot without connecting to or displaying the cameras:
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --robot.cameras='{}' \
+  --control.type=teleoperate
+```
+
+#### b. Teleop with displaying cameras
+You can display the cameras while teleoperating by running the following command. This is useful to prepare your setup before recording your first dataset.
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --control.type=teleoperate
+```
+
+## C. Record a dataset
+
+Once you're familiar with teleoperation, you can record your first dataset with roarm_m3.
+
+If you want to use the Hugging Face Hub features for uploading your dataset and you haven't done so before, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
+```bash
+huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
+```
+
+Store your Hugging Face repository name in a variable to run these commands:
+```bash
+HF_USER=$(huggingface-cli whoami | head -n 1)
+echo $HF_USER
+```
+
+Record 2 episodes and upload your dataset to the hub:
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --control.type=record \
+  --control.fps=30 \
+  --control.single_task="Grasp a block and put it in the bin." \
+  --control.repo_id=${HF_USER}/roarm_m3_test \
+  --control.tags='["roarm_m3","tutorial"]' \
+  --control.warmup_time_s=5 \
+  --control.episode_time_s=30 \
+  --control.reset_time_s=30 \
+  --control.num_episodes=2 \
+  --control.push_to_hub=true
+```
+
+Note: You can resume recording by adding `--control.resume=true`.
+
+## D. Visualize a dataset
+
+If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy-pasting your repo id given by:
+```bash
+echo ${HF_USER}/roarm_m3_test
+```
+
+If you didn't upload the dataset (i.e. you used `--control.push_to_hub=false`), you can also visualize it locally. The visualization tool serves a web page at `http://ip:9090`, where `ip` is the host address you pass below:
+```bash
+python lerobot/scripts/visualize_dataset_html.py \
+  --repo-id ${HF_USER}/roarm_m3_test \
+  --host ip \
+  --local-files-only 1
+```
+
+## E. Replay an episode
+
+Now try to replay the nth recorded episode on your robot (episode indices start at 0, hence `n-1`):
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --control.type=replay \
+  --control.fps=30 \
+  --control.repo_id=${HF_USER}/roarm_m3_test \
+  --control.episode=n-1
+```
+
+## F. Train a policy
+
+To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
+```bash
+python lerobot/scripts/train.py \
+  --dataset.repo_id=${HF_USER}/roarm_m3_test \
+  --policy.type=act \
+  --output_dir=outputs/train/act_roarm_m3_test \
+  --job_name=act_roarm_m3_test \
+  --device=cuda \
+  --wandb.enable=true
+```
+
+Let's explain the arguments:
+1. We provided the dataset with `--dataset.repo_id=${HF_USER}/roarm_m3_test`.
+2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy automatically adapts to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
+3. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional, but if you use it, make sure you are logged in by running `wandb login`.
+
+Training should take several hours. You will find checkpoints in `outputs/train/act_roarm_m3_test/checkpoints`.
+
+## G. Evaluate your policy
+
+You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
+```bash
+python lerobot/scripts/control_robot.py \
+  --robot.type=roarm_m3 \
+  --control.type=record \
+  --control.fps=30 \
+  --control.single_task="Grasp a block and put it in the bin." \
+  --control.repo_id=${HF_USER}/eval_act_roarm_m3_test \
+  --control.tags='["tutorial"]' \
+  --control.warmup_time_s=5 \
+  --control.episode_time_s=30 \
+  --control.reset_time_s=30 \
+  --control.num_episodes=10 \
+  --control.push_to_hub=true \
+  --control.policy.path=outputs/train/act_roarm_m3_test/checkpoints/last/pretrained_model
+```
+
+As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
+1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_roarm_m3_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_roarm_m3_test`).
+2. The name of the dataset begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_roarm_m3_test`).
+
+## H. More Information
+
+Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
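+
+This port also declares the `roarm-sdk` dependency as an optional extra in `pyproject.toml`. A minimal sketch of installing it, assuming you installed LeRobot from source in editable mode:
+```bash
+pip install -e ".[roarm]"
+```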
diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
index acf0282f..f5769dc5 100644
--- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
+++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
@@ -166,6 +166,9 @@ def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
             for arm in robot_cfg.leader_arms
             for motor in robot_cfg.leader_arms[arm].motors
         ]
+    elif robot_cfg.type == "roarm_m3":
+        state_names = ["1", "2", "3", "4", "5", "6"]
+        action_names = ["1", "2", "3", "4", "5", "6"]
     # elif robot_cfg["robot_type"] == "stretch3": TODO
     else:
         raise NotImplementedError(
diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py
index f279f315..998fd54a 100644
--- a/lerobot/common/robot_devices/cameras/opencv.py
+++ b/lerobot/common/robot_devices/cameras/opencv.py
@@ -44,6 +44,20 @@ from lerobot.common.utils.utils import capture_timestamp_utc
 # treat the same cameras as new devices. Thus we select a higher bound to search indices.
 MAX_OPENCV_INDEX = 60
 
+undistort = True
+
+
+def undistort_image(image):
+    import cv2
+
+    camera_matrix = np.array([[289.11451, 0.0, 347.23664], [0.0, 289.75319, 235.67429], [0.0, 0.0, 1.0]])
+
+    dist_coeffs = np.array([-0.208848, 0.028006, -0.000705, -0.000820, 0.000000])
+
+    undistorted_image = cv2.undistort(image, camera_matrix, dist_coeffs)
+
+    return undistorted_image
+
 
 def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
     cameras = []
@@ -404,6 +418,9 @@ class OpenCVCamera:
 
         color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
 
+        if undistort:
+            color_image = undistort_image(color_image)
+
         h, w, _ = color_image.shape
         if h != self.capture_height or w != self.capture_width:
             raise OSError(
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index 78a8c6a6..e774bc35 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -128,42 +128,73 @@ def predict_action(observation, policy, device, use_amp):
     return action
 
 
+# def init_keyboard_listener():
+#     # Allow to exit early while recording an episode or resetting the environment,
+#     # by tapping the right arrow key '->'. This might require a sudo permission
+#     # to allow your terminal to monitor keyboard events.
+#     events = {}
+#     events["exit_early"] = False
+#     events["rerecord_episode"] = False
+#     events["stop_recording"] = False
+
+#     if is_headless():
+#         logging.warning(
+#             "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
+#         )
+#         listener = None
+#         return listener, events
+
+#     # Only import pynput if not in a headless environment
+#     from pynput import keyboard
+
+#     def on_press(key):
+#         try:
+#             if key == keyboard.Key.right:
+#                 print("Right arrow key pressed. Exiting loop...")
+#                 events["exit_early"] = True
+#             elif key == keyboard.Key.left:
+#                 print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
+#                 events["rerecord_episode"] = True
+#                 events["exit_early"] = True
+#             elif key == keyboard.Key.esc:
+#                 print("Escape key pressed. Stopping data recording...")
+#                 events["stop_recording"] = True
+#                 events["exit_early"] = True
+#         except Exception as e:
+#             print(f"Error handling key press: {e}")
+
+#     listener = keyboard.Listener(on_press=on_press)
+#     listener.start()
+
+#     return listener, events
+
+
 def init_keyboard_listener():
-    # Allow to exit early while recording an episode or resetting the environment,
-    # by tapping the right arrow key '->'. This might require a sudo permission
-    # to allow your terminal to monitor keyboard events.
     events = {}
     events["exit_early"] = False
     events["rerecord_episode"] = False
     events["stop_recording"] = False
+    import threading
 
-    if is_headless():
-        logging.warning(
-            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
-        )
-        listener = None
-        return listener, events
-
-    # Only import pynput if not in a headless environment
-    from pynput import keyboard
+    from sshkeyboard import listen_keyboard
 
     def on_press(key):
         try:
-            if key == keyboard.Key.right:
+            if key == "right":
                 print("Right arrow key pressed. Exiting loop...")
                 events["exit_early"] = True
-            elif key == keyboard.Key.left:
+            elif key == "left":
                 print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                 events["rerecord_episode"] = True
                 events["exit_early"] = True
-            elif key == keyboard.Key.esc:
-                print("Escape key pressed. Stopping data recording...")
+            elif key == "q":
                print("Q key pressed. Stopping data recording...")
                 events["stop_recording"] = True
                 events["exit_early"] = True
         except Exception as e:
             print(f"Error handling key press: {e}")
 
-    listener = keyboard.Listener(on_press=on_press)
+    listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
     listener.start()
 
     return listener, events
@@ -264,7 +295,8 @@ def control_loop(
             frame = {**observation, **action, "task": single_task}
             dataset.add_frame(frame)
 
-        if display_cameras and not is_headless():
+        if display_cameras:
+            # if display_cameras and not is_headless():
             image_keys = [key for key in observation if "image" in key]
             for key in image_keys:
                 cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@@ -297,15 +329,27 @@ def reset_environment(robot, events, reset_time_s, fps):
     )
 
 
+# def stop_recording(robot, listener, display_cameras):
+#     robot.disconnect()
+
+#     if not is_headless():
+#         if listener is not None:
+#             listener.stop()
+
+#         if display_cameras:
+#             cv2.destroyAllWindows()
+
+
 def stop_recording(robot, listener, display_cameras):
     robot.disconnect()
 
-    if not is_headless():
-        if listener is not None:
-            listener.stop()
+    from sshkeyboard import stop_listening
 
-        if display_cameras:
-            cv2.destroyAllWindows()
+    if listener is not None:
+        stop_listening()
+
+    if display_cameras:
+        cv2.destroyAllWindows()
 
 
 def sanity_check_dataset_name(repo_id, policy_cfg):
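For readers unfamiliar with `sshkeyboard`: unlike `pynput`, it reads keys from the terminal, so the listener above should also work over SSH or without a display, and the stop key becomes `q` instead of `Esc`. A minimal standalone sketch of the same pattern used by the patched `init_keyboard_listener` (the polling loop is illustrative, not part of the patch):

```python
import threading
import time

from sshkeyboard import listen_keyboard, stop_listening

events = {"exit_early": False, "rerecord_episode": False, "stop_recording": False}


def on_press(key):
    # sshkeyboard passes keys as plain strings, e.g. "right", "left", "q"
    if key == "right":
        events["exit_early"] = True
    elif key == "left":
        events["rerecord_episode"] = True
        events["exit_early"] = True
    elif key == "q":
        events["stop_recording"] = True
        events["exit_early"] = True


# Run the blocking listener in a background thread, as the patched function does
listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
listener.start()

while not events["stop_recording"]:  # a recording loop would poll `events` like this
    time.sleep(0.1)

stop_listening()  # unblocks listen_keyboard so the thread can finish
listener.join()
```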
diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py
index e940b442..6a4140af 100644
--- a/lerobot/common/robot_devices/robots/configs.py
+++ b/lerobot/common/robot_devices/robots/configs.py
@@ -494,6 +494,47 @@ class So100RobotConfig(ManipulatorRobotConfig):
     mock: bool = False
 
 
+@RobotConfig.register_subclass("roarm_m3")
+@dataclass
+class RoarmRobotConfig(RobotConfig):
+    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
+    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
+    # the number of motors in your follower arms.
+    max_relative_target: int | None = None
+
+    leader_arms: dict[str, str] = field(
+        default_factory=lambda: {
+            "main": "/dev/ttyUSB0",
+        }
+    )
+
+    # Follower arm configuration: serial port of the follower arm
+    follower_arms: dict[str, str] = field(
+        default_factory=lambda: {
+            "main": "/dev/ttyUSB1",
+        }
+    )
+
+    cameras: dict[str, CameraConfig] = field(
+        default_factory=lambda: {
+            "laptop": OpenCVCameraConfig(
+                camera_index=0,
+                fps=30,
+                width=640,
+                height=480,
+            ),
+            "phone": OpenCVCameraConfig(
+                camera_index=2,
+                fps=30,
+                width=640,
+                height=480,
+            ),
+        }
+    )
+
+    mock: bool = False
+
+
 @RobotConfig.register_subclass("stretch")
 @dataclass
 class StretchRobotConfig(RobotConfig):
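Since `RoarmRobotConfig` is a plain dataclass, the ports and cameras above can also be overridden directly in Python instead of through the CLI. A minimal connectivity-check sketch using only classes added or touched in this diff; the serial port paths are illustrative assumptions, not required defaults:

```python
from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config

# Hypothetical ports: adjust to whatever /dev/ttyUSB* your leader/follower arms enumerate as.
config = RoarmRobotConfig(
    leader_arms={"main": "/dev/ttyUSB0"},
    follower_arms={"main": "/dev/ttyUSB1"},
    cameras={},  # skip cameras for a quick connectivity check
    max_relative_target=None,
)

robot = make_robot_from_config(config)  # dispatches to RoarmRobot via the new branch in robots/utils.py below
robot.connect()
robot.teleop_step()  # one leader -> follower position update
robot.disconnect()
```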
diff --git a/lerobot/common/robot_devices/robots/roarm_m3.py b/lerobot/common/robot_devices/robots/roarm_m3.py
new file mode 100644
index 00000000..941ca4ec
--- /dev/null
+++ b/lerobot/common/robot_devices/robots/roarm_m3.py
@@ -0,0 +1,305 @@
+"""Contains logic to instantiate a robot, read information from its motors and cameras,
+and send orders to its motors.
+"""
+# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
+# calibration procedure, to make it easy for people to add their own robot.
+
+import logging
+import time
+
+import numpy as np
+import torch
+from roarm_sdk.roarm import roarm
+
+from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
+from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
+from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
+
+
+def ensure_safe_goal_position(
+    goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
+):
+    # Cap relative action target magnitude for safety.
+    diff = goal_pos - present_pos
+    max_relative_target = torch.tensor(max_relative_target)
+    safe_diff = torch.minimum(diff, max_relative_target)
+    safe_diff = torch.maximum(safe_diff, -max_relative_target)
+    safe_goal_pos = present_pos + safe_diff
+
+    if not torch.allclose(goal_pos, safe_goal_pos):
+        logging.warning(
+            "Relative goal position magnitude had to be clamped to be safe.\n"
+            f"  requested relative goal position target: {diff}\n"
+            f"    clamped relative goal position target: {safe_diff}"
+        )
+
+    return safe_goal_pos
+
+
+def make_roarm_from_configs(configs: dict[str, str]) -> dict[str, roarm]:
+    roarms = {}
+
+    for key, port in configs.items():
+        roarms[key] = roarm(roarm_type="roarm_m3", port=port, baudrate=115200)
+
+    return roarms
+
+
+class RoarmRobot:
+    def __init__(
+        self,
+        config: RoarmRobotConfig,
+    ):
+        self.config = config
+        self.robot_type = self.config.type
+        self.leader_arms = make_roarm_from_configs(self.config.leader_arms)
+        self.follower_arms = make_roarm_from_configs(self.config.follower_arms)
+        self.cameras = make_cameras_from_configs(self.config.cameras)
+        self.is_connected = False
+        self.logs = {}
+
+    @property
+    def camera_features(self) -> dict:
+        cam_ft = {}
+        for cam_key, cam in self.cameras.items():
+            key = f"observation.images.{cam_key}"
+            cam_ft[key] = {
+                "shape": (cam.height, cam.width, cam.channels),
+                "names": ["height", "width", "channels"],
+                "info": None,
+            }
+        return cam_ft
+
+    @property
+    def motor_features(self) -> dict:
+        return {
+            "action": {
+                "dtype": "float32",
+                "shape": (6,),
+                "names": ["1", "2", "3", "4", "5", "6"],
+            },
+            "observation.state": {
+                "dtype": "float32",
+                "shape": (6,),
+                "names": ["1", "2", "3", "4", "5", "6"],
+            },
+        }
+
+    @property
+    def features(self):
+        return {**self.motor_features, **self.camera_features}
+
+    @property
+    def has_camera(self):
+        return len(self.cameras) > 0
+
+    @property
+    def num_cameras(self):
+        return len(self.cameras)
+
+    def connect(self):
+        if self.is_connected:
+            raise RobotDeviceAlreadyConnectedError(
+                "RoarmRobot is already connected. Do not run `robot.connect()` twice."
+            )
+
+        if not self.leader_arms and not self.follower_arms and not self.cameras:
+            raise ValueError(
+                "RoarmRobot doesn't have any device to connect. See example of usage in docstring of the class."
+            )
+
+        for name in self.follower_arms:
+            self.follower_arms[name].joints_angle_get()
+        for name in self.leader_arms:
+            self.leader_arms[name].joints_angle_get()
+
+        # Connect the cameras
+        for name in self.cameras:
+            self.cameras[name].connect()
+
+        self.is_connected = True
+
+    def teleop_step(
+        self, record_data=False
+    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                "RoarmRobot is not connected. You need to run `robot.connect()`."
+            )
+
+        # Prepare to assign the position of the leader to the follower
+        leader_pos = {}
+        for name in self.leader_arms:
+            before_lread_t = time.perf_counter()
+            leader_pos[name] = np.array(self.leader_arms[name].joints_angle_get(), dtype=np.float32)
+            leader_pos[name] = torch.from_numpy(leader_pos[name])
+            self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
+
+        # Send goal position to the follower
+        follower_goal_pos = {}
+        for name in self.follower_arms:
+            before_fwrite_t = time.perf_counter()
+            goal_pos = leader_pos[name]
+
+            # Cap goal position when too far away from present position.
+            # Slower fps expected due to reading from the follower.
+            if self.config.max_relative_target is not None:
+                present_pos = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
+                present_pos = torch.from_numpy(present_pos)
+                goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
+
+            # Used when record_data=True
+            follower_goal_pos[name] = goal_pos
+
+            goal_pos = goal_pos.numpy().astype(np.int32)
+            goal_pos = goal_pos.tolist() if isinstance(goal_pos, (np.ndarray, torch.Tensor)) else goal_pos
+            self.follower_arms[name].joints_angle_ctrl(angles=goal_pos, speed=0, acc=0)
+            self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
+
+        # Early exit when recording data is not requested
+        if not record_data:
+            return
+
+        # TODO(rcadene): Add velocity and other info
+        # Read follower position
+        follower_pos = {}
+        for name in self.follower_arms:
+            before_fread_t = time.perf_counter()
+            follower_pos[name] = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
+            follower_pos[name] = torch.from_numpy(follower_pos[name])
+            self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
+
+        # Create state by concatenating follower current position
+        state = []
+        for name in self.follower_arms:
+            if name in follower_pos:
+                state.append(follower_pos[name])
+        state = torch.cat(state)
+
+        # Create action by concatenating follower goal position
+        action = []
+        for name in self.follower_arms:
+            if name in follower_goal_pos:
+                action.append(follower_goal_pos[name])
+        action = torch.cat(action)
+
+        # Capture images from cameras
+        images = {}
+        for name in self.cameras:
+            before_camread_t = time.perf_counter()
+            images[name] = self.cameras[name].async_read()
+            images[name] = torch.from_numpy(images[name])
+            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
+            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
+
+        # Populate output dictionaries
+        obs_dict, action_dict = {}, {}
+        obs_dict["observation.state"] = state
+        action_dict["action"] = action
+        for name in self.cameras:
+            obs_dict[f"observation.images.{name}"] = images[name]
+
+        return obs_dict, action_dict
+
+    def capture_observation(self):
+        """The returned observations do not have a batch dimension."""
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                "RoarmRobot is not connected. You need to run `robot.connect()`."
+            )
+
+        # Read follower position
+        follower_pos = {}
+        for name in self.follower_arms:
+            before_fread_t = time.perf_counter()
+            follower_pos[name] = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
+            follower_pos[name] = torch.from_numpy(follower_pos[name])
+            self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
+
+        # Create state by concatenating follower current position
+        state = []
+        for name in self.follower_arms:
+            if name in follower_pos:
+                state.append(follower_pos[name])
+        state = torch.cat(state)
+
+        # Capture images from cameras
+        images = {}
+        for name in self.cameras:
+            before_camread_t = time.perf_counter()
+            images[name] = self.cameras[name].async_read()
+            images[name] = torch.from_numpy(images[name])
+            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
+            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
+
+        # Populate output dictionaries and format to pytorch
+        obs_dict = {}
+        obs_dict["observation.state"] = state
+        for name in self.cameras:
+            obs_dict[f"observation.images.{name}"] = images[name]
+        return obs_dict
+
+    def send_action(self, action: torch.Tensor) -> torch.Tensor:
+        """Command the follower arms to move to a target joint configuration.
+
+        The relative action magnitude may be clipped depending on the configuration parameter
+        `max_relative_target`. In this case, the action sent differs from the original action.
+        Thus, this function always returns the action actually sent.
+
+        Args:
+            action: tensor containing the concatenated goal positions for the follower arms.
+        """
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                "RoarmRobot is not connected. You need to run `robot.connect()`."
+            )
+
+        from_idx = 0
+        to_idx = 6
+        action_sent = []
+        for name in self.follower_arms:
+            # Get goal position of each follower arm by splitting the action vector
+            goal_pos = action[from_idx:to_idx]
+            from_idx = to_idx
+
+            # Cap goal position when too far away from present position.
+            # Slower fps expected due to reading from the follower.
+            if self.config.max_relative_target is not None:
+                present_pos = np.array(self.follower_arms[name].joints_angle_get(), dtype=np.float32)
+                present_pos = torch.from_numpy(present_pos)
+                goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
+
+            # Save tensor to concat and return
+            action_sent.append(goal_pos)
+
+            # Send goal position to each follower
+            goal_pos = goal_pos.numpy().astype(np.int32)
+            goal_pos = goal_pos.tolist() if isinstance(goal_pos, (np.ndarray, torch.Tensor)) else goal_pos
+            self.follower_arms[name].joints_angle_ctrl(angles=goal_pos, speed=0, acc=0)
+
+        return torch.cat(action_sent)
+
+    def print_logs(self):
+        pass
+        # TODO(aliberts): move robot-specific logs logic here
+
+    def disconnect(self):
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                "RoarmRobot is not connected. You need to run `robot.connect()` before disconnecting."
+            )
+
+        for name in self.follower_arms:
+            self.follower_arms[name].disconnect()
+
+        for name in self.leader_arms:
+            self.leader_arms[name].disconnect()
+
+        for name in self.cameras:
+            self.cameras[name].disconnect()
+
+        self.is_connected = False
+
+    def __del__(self):
+        if getattr(self, "is_connected", False):
+            self.disconnect()
diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py
index dab514d5..b864fa1b 100644
--- a/lerobot/common/robot_devices/robots/utils.py
+++ b/lerobot/common/robot_devices/robots/utils.py
@@ -21,6 +21,7 @@ from lerobot.common.robot_devices.robots.configs import (
     LeKiwiRobotConfig,
     ManipulatorRobotConfig,
     MossRobotConfig,
+    RoarmRobotConfig,
     RobotConfig,
     So100RobotConfig,
     StretchRobotConfig,
@@ -58,6 +59,8 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
         return MossRobotConfig(**kwargs)
     elif robot_type == "so100":
         return So100RobotConfig(**kwargs)
+    elif robot_type == "roarm_m3":
+        return RoarmRobotConfig(**kwargs)
     elif robot_type == "stretch":
         return StretchRobotConfig(**kwargs)
     elif robot_type == "lekiwi":
@@ -75,6 +78,10 @@ def make_robot_from_config(config: RobotConfig):
         from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
 
         return MobileManipulator(config)
+    elif isinstance(config, RoarmRobotConfig):
+        from lerobot.common.robot_devices.robots.roarm_m3 import RoarmRobot
+
+        return RoarmRobot(config)
     else:
         from lerobot.common.robot_devices.robots.stretch import StretchRobot
diff --git a/pyproject.toml b/pyproject.toml
index 1fa7b246..3124a933 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,6 +73,7 @@ dependencies = [
     "torchvision>=0.21.0",
     "wandb>=0.16.3",
     "zarr>=2.17.0",
+    "sshkeyboard",
 ]
 
 [project.optional-dependencies]
@@ -96,6 +97,7 @@ test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
 umi = ["imagecodecs>=2024.1.1"]
 video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
 xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
+roarm = ["roarm-sdk==0.0.11"]
 
 [tool.poetry]
 requires-poetry = ">=2.1"
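For context on the `roarm_sdk` surface that `RoarmRobot` relies on, here is a minimal sketch exercising only the calls that appear in `roarm_m3.py` above (`roarm(...)`, `joints_angle_get`, `joints_angle_ctrl`, `disconnect`). The port is an assumption to adapt to your machine, and the SDK may offer more than is shown here:

```python
from roarm_sdk.roarm import roarm

# Mirrors make_roarm_from_configs(): hypothetical port, baudrate as used in the patch.
arm = roarm(roarm_type="roarm_m3", port="/dev/ttyUSB1", baudrate=115200)

# Read the six joint angles, as RoarmRobot does for observations.
angles = arm.joints_angle_get()
print("current joint angles:", angles)

# Command the same pose back (a no-op move); roarm_m3.py casts targets to int before sending.
arm.joints_angle_ctrl(angles=[int(a) for a in angles], speed=0, acc=0)

arm.disconnect()
```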