diff --git a/lerobot/common/robot_devices/cameras/reachy2.py b/lerobot/common/robot_devices/cameras/reachy2.py index c75dab3e..c581c3e1 100644 --- a/lerobot/common/robot_devices/cameras/reachy2.py +++ b/lerobot/common/robot_devices/cameras/reachy2.py @@ -2,34 +2,15 @@ Wrapper for Reachy2 camera from sdk """ -import argparse -import concurrent.futures -import math -import platform -import shutil -import threading -import time -from dataclasses import dataclass, replace -from pathlib import Path -from threading import Thread +from dataclasses import dataclass import numpy as np -from PIL import Image - -from lerobot.common.robot_devices.utils import ( - RobotDeviceAlreadyConnectedError, - RobotDeviceNotConnectedError, - busy_wait, -) -from lerobot.common.utils.utils import capture_timestamp_utc - from reachy2_sdk.media.camera import CameraView from reachy2_sdk.media.camera_manager import CameraManager @dataclass class ReachyCameraConfig: - fps: int | None = None width: int | None = None height: int | None = None @@ -46,7 +27,7 @@ class ReachyCamera: name: str, image_type: str, config: ReachyCameraConfig | None = None, - **kwargs + **kwargs, ): self.host = host self.port = port @@ -64,7 +45,6 @@ class ReachyCamera: self.is_connected = True def read(self) -> np.ndarray: - if not self.is_connected: self.connect() @@ -78,7 +58,6 @@ class ReachyCamera: else: return None elif self.name == "depth" and hasattr(self.cam_manager, "depth"): - if self.image_type == "depth": return self.cam_manager.depth.get_depth_frame() elif self.image_type == "rgb": diff --git a/lerobot/common/robot_devices/robots/reachy2.py b/lerobot/common/robot_devices/robots/reachy2.py index 252148f5..2667f3e2 100644 --- a/lerobot/common/robot_devices/robots/reachy2.py +++ b/lerobot/common/robot_devices/robots/reachy2.py @@ -18,17 +18,40 @@ import time from copy import copy from dataclasses import dataclass, field, replace -import numpy as np import torch from reachy2_sdk import ReachySDK from lerobot.common.robot_devices.cameras.utils import Camera +REACHY_MOTORS = [ + "neck_yaw.pos", + "neck_pitch.pos", + "neck_roll.pos", + "r_shoulder_pitch.pos", + "r_shoulder_roll.pos", + "r_elbow_yaw.pos", + "r_elbow_pitch.pos", + "r_wrist_roll.pos", + "r_wrist_pitch.pos", + "r_wrist_yaw.pos", + "r_gripper.pos", + "l_shoulder_pitch.pos", + "l_shoulder_roll.pos", + "l_elbow_yaw.pos", + "l_elbow_pitch.pos", + "l_wrist_roll.pos", + "l_wrist_pitch.pos", + "l_wrist_yaw.pos", + "l_gripper.pos", + "mobile_base.vx", + "mobile_base.vy", + "mobile_base.vtheta", +] @dataclass class ReachyRobotConfig: - robot_type: str | None = "Reachy2" + robot_type: str | None = "reachy2" cameras: dict[str, Camera] = field(default_factory=lambda: {}) ip_address: str | None = "172.17.135.207" # ip_address: str | None = "localhost" @@ -59,6 +82,37 @@ class ReachyRobot: self.state_keys = None self.action_keys = None + @property + def camera_features(self) -> dict: + cam_ft = {} + for cam_key, cam in self.cameras.items(): + key = f"observation.images.{cam_key}" + cam_ft[key] = { + "shape": (cam.height, cam.width, cam.channels), + "names": ["height", "width", "channels"], + "info": None, + } + return cam_ft + + @property + def motor_features(self) -> dict: + return { + "action": { + "dtype": "float32", + "shape": (len(REACHY_MOTORS),), + "names": REACHY_MOTORS, + }, + "observation.state": { + "dtype": "float32", + "shape": (len(REACHY_MOTORS),), + "names": REACHY_MOTORS, + }, + } + + @property + def features(self): + return {**self.motor_features, **self.camera_features} + def connect(self) -> None: print("Connecting to Reachy") self.reachy.is_connected = self.reachy.connect() @@ -73,14 +127,10 @@ class ReachyRobot: for name in self.cameras: print(f"Connecting camera: {name}") self.cameras[name].connect() - self.is_connected = ( - self.is_connected and self.cameras[name].is_connected - ) + self.is_connected = self.is_connected and self.cameras[name].is_connected if not self.is_connected: - print( - "Could not connect to the cameras, check that all cameras are plugged-in." - ) + print("Could not connect to the cameras, check that all cameras are plugged-in.") raise ConnectionError() def run_calibration(self): @@ -119,7 +169,7 @@ class ReachyRobot: action["mobile_base_x.vel"] = last_cmd_vel["x"] action["mobile_base_y.vel"] = last_cmd_vel["y"] action["mobile_base_theta.vel"] = last_cmd_vel["theta"] - + action = torch.as_tensor(list(action.values())) obs_dict = self.capture_observation() @@ -179,18 +229,12 @@ class ReachyRobot: # Capture images from cameras images = {} for name in self.cameras: - before_camread_t = time.perf_counter() - images[name] = self.cameras[ - name - ].read() # Reachy cameras read() is not blocking? + # before_camread_t = time.perf_counter() + images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking? # print(f'name: {name} img: {images[name]}') if images[name] is not None: - images[name] = torch.from_numpy( - copy(images[name][0]) - ) # seems like I need to copy? - self.logs[f"read_camera_{name}_dt_s"] = images[name][ - 1 - ] # full timestamp, TODO dt + images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy? + self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt # Populate output dictionnaries obs_dict = {} diff --git a/lerobot/configs/robot/reachy2.yaml b/lerobot/configs/robot/reachy2.yaml index 456387e4..5ec9c23c 100644 --- a/lerobot/configs/robot/reachy2.yaml +++ b/lerobot/configs/robot/reachy2.yaml @@ -6,8 +6,7 @@ _target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobot - -robot_type: Reachy2 +robot_type: reachy2 cameras: head_left: diff --git a/lerobot/scripts/test_reachy.py b/lerobot/scripts/test_reachy.py index 79cb10d8..6aba324e 100644 --- a/lerobot/scripts/test_reachy.py +++ b/lerobot/scripts/test_reachy.py @@ -1,30 +1,10 @@ - -import argparse -import logging import time -from pathlib import Path -from typing import List # from safetensors.torch import load_file, save_file -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.populate_dataset import (create_lerobot_dataset, - delete_current_episode, - init_dataset, - save_current_episode) -from lerobot.common.robot_devices.control_utils import ( - control_loop, has_method, init_keyboard_listener, init_policy, - log_control_info, record_episode, reset_environment, - sanity_check_dataset_name, stop_recording, warmup_record) from lerobot.common.robot_devices.robots.factory import make_robot -from lerobot.common.robot_devices.robots.utils import Robot -from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect -from lerobot.common.utils.utils import (init_hydra_config, init_logging, - log_say, none_or_int) +from lerobot.common.utils.utils import init_hydra_config, init_logging - -import time -import cv2 -if __name__ == '__main__': +if __name__ == "__main__": init_logging() control_mode = "test"