diff --git a/lerobot/common/robot_devices/cameras/reachy2.py b/lerobot/common/robot_devices/cameras/reachy2.py index c581c3e1..24040034 100644 --- a/lerobot/common/robot_devices/cameras/reachy2.py +++ b/lerobot/common/robot_devices/cameras/reachy2.py @@ -2,8 +2,9 @@ Wrapper for Reachy2 camera from sdk """ -from dataclasses import dataclass +from dataclasses import dataclass, replace +import cv2 import numpy as np from reachy2_sdk.media.camera import CameraView from reachy2_sdk.media.camera_manager import CameraManager @@ -18,6 +19,14 @@ class ReachyCameraConfig: rotation: int | None = None mock: bool = False + def __post_init__(self): + if self.color_mode not in ["rgb", "bgr"]: + raise ValueError( + f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." + ) + + self.channels = 3 + class ReachyCamera: def __init__( @@ -29,8 +38,18 @@ class ReachyCamera: config: ReachyCameraConfig | None = None, **kwargs, ): + if config is None: + config = ReachyCameraConfig() + + # Overwrite config arguments using kwargs + config = replace(config, **kwargs) + self.host = host self.port = port + self.width = config.width + self.height = config.height + self.channels = config.channels + self.fps = config.fps self.image_type = image_type self.name = name self.config = config @@ -48,21 +67,24 @@ class ReachyCamera: if not self.is_connected: self.connect() + frame = None + if self.name == "teleop" and hasattr(self.cam_manager, "teleop"): if self.image_type == "left": - return self.cam_manager.teleop.get_frame(CameraView.LEFT) - # return self.cam_manager.teleop.get_compressed_frame(CameraView.LEFT) + frame = self.cam_manager.teleop.get_frame(CameraView.LEFT) elif self.image_type == "right": - return self.cam_manager.teleop.get_frame(CameraView.RIGHT) - # return self.cam_manager.teleop.get_compressed_frame(CameraView.RIGHT) - else: - return None + frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT) elif self.name == "depth" and hasattr(self.cam_manager, "depth"): if self.image_type == "depth": - return self.cam_manager.depth.get_depth_frame() + frame = self.cam_manager.depth.get_depth_frame() elif self.image_type == "rgb": - return self.cam_manager.depth.get_frame() - # return self.cam_manager.depth.get_compressed_frame() - else: - return None - return None + frame = self.cam_manager.depth.get_frame() + + if frame is None: + return None + + if frame is not None and self.config.color_mode == "rgb": + img, timestamp = frame + frame = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB), timestamp) + + return frame diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 87690217..b2d54a66 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -46,7 +46,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f log_dt("dt", dt_s) # TODO(aliberts): move robot-specific logs logic in robot.print_logs() - if not robot.robot_type.startswith(("stretch", "Reachy")): + if not robot.robot_type.lower().startswith(("stretch", "reachy")): for name in robot.leader_arms: key = f"read_leader_{name}_pos_dt_s" if key in robot.logs: diff --git a/lerobot/common/robot_devices/robots/reachy2.py b/lerobot/common/robot_devices/robots/reachy2.py index 2667f3e2..d048d6f9 100644 --- a/lerobot/common/robot_devices/robots/reachy2.py +++ b/lerobot/common/robot_devices/robots/reachy2.py @@ -18,10 +18,11 @@ import time from copy import copy from dataclasses import dataclass, field, replace +import numpy as np import torch from reachy2_sdk import ReachySDK -from lerobot.common.robot_devices.cameras.utils import Camera +from lerobot.common.robot_devices.cameras.reachy2 import ReachyCamera REACHY_MOTORS = [ "neck_yaw.pos", @@ -52,8 +53,9 @@ REACHY_MOTORS = [ @dataclass class ReachyRobotConfig: robot_type: str | None = "reachy2" - cameras: dict[str, Camera] = field(default_factory=lambda: {}) + cameras: dict[str, ReachyCamera] = field(default_factory=lambda: {}) ip_address: str | None = "172.17.135.207" + # ip_address: str | None = "192.168.0.197" # ip_address: str | None = "localhost" @@ -74,10 +76,8 @@ class ReachyRobot: self.is_connected = False self.teleop = None self.logs = {} - self.reachy: ReachySDK = ReachySDK(host=config.ip_address) - self.reachy.turn_on() - self.is_connected = True # at init Reachy2 is in fact connected... - self.mobile_base_available = self.reachy.mobile_base is not None + self.reachy = None + self.mobile_base_available = False self.state_keys = None self.action_keys = None @@ -96,16 +96,19 @@ class ReachyRobot: @property def motor_features(self) -> dict: + motors = REACHY_MOTORS + # if self.mobile_base_available: + # motors += REACHY_MOBILE_BASE return { "action": { "dtype": "float32", - "shape": (len(REACHY_MOTORS),), - "names": REACHY_MOTORS, + "shape": (len(motors),), + "names": motors, }, "observation.state": { "dtype": "float32", - "shape": (len(REACHY_MOTORS),), - "names": REACHY_MOTORS, + "shape": (len(motors),), + "names": motors, }, } @@ -114,14 +117,16 @@ class ReachyRobot: return {**self.motor_features, **self.camera_features} def connect(self) -> None: + self.reachy = ReachySDK(host=self.config.ip_address) print("Connecting to Reachy") - self.reachy.is_connected = self.reachy.connect() + self.reachy.connect() + self.is_connected = self.reachy.is_connected if not self.is_connected: print( f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists." ) raise ConnectionError() - self.reachy.turn_on() + # self.reachy.turn_on() print(self.cameras) if self.cameras is not None: for name in self.cameras: @@ -133,6 +138,8 @@ class ReachyRobot: print("Could not connect to the cameras, check that all cameras are plugged-in.") raise ConnectionError() + self.mobile_base_available = self.reachy.mobile_base is not None + def run_calibration(self): pass @@ -169,8 +176,14 @@ class ReachyRobot: action["mobile_base_x.vel"] = last_cmd_vel["x"] action["mobile_base_y.vel"] = last_cmd_vel["y"] action["mobile_base_theta.vel"] = last_cmd_vel["theta"] + else: + action["mobile_base_x.vel"] = 0 + action["mobile_base_y.vel"] = 0 + action["mobile_base_theta.vel"] = 0 - action = torch.as_tensor(list(action.values())) + dtype = self.motor_features["action"]["dtype"] + action = np.array(list(action.values()), dtype=dtype) + # action = torch.as_tensor(list(action.values())) obs_dict = self.capture_observation() action_dict = {} @@ -224,7 +237,9 @@ class ReachyRobot: if self.state_keys is None: self.state_keys = list(state) - state = torch.as_tensor(list(state.values())) + dtype = self.motor_features["observation.state"]["dtype"] + state = np.array(list(state.values()), dtype=dtype) + # state = torch.as_tensor(list(state.values())) # Capture images from cameras images = {} @@ -233,6 +248,7 @@ class ReachyRobot: images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking? # print(f'name: {name} img: {images[name]}') if images[name] is not None: + # images[name] = copy(images[name][0]) # seems like I need to copy? images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy? self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt @@ -295,7 +311,7 @@ class ReachyRobot: print("Disconnecting") self.is_connected = False print("Turn off") - self.reachy.turn_off_smoothly() + # self.reachy.turn_off_smoothly() # self.reachy.turn_off() print("\t turn off done") self.reachy.disconnect() diff --git a/lerobot/configs/robot/reachy2.yaml b/lerobot/configs/robot/reachy2.yaml index 5ec9c23c..4fbfb75e 100644 --- a/lerobot/configs/robot/reachy2.yaml +++ b/lerobot/configs/robot/reachy2.yaml @@ -12,9 +12,13 @@ cameras: head_left: _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera name: teleop - host: 172.17.135.207 + host: 172.17.134.85 + # host: 192.168.0.197 # host: localhost port: 50065 + fps: 30 + width: 960 + height: 720 image_type: left # head_right: # _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera @@ -22,6 +26,9 @@ cameras: # host: 172.17.135.207 # port: 50065 # image_type: right + # fps: 30 + # width: 960 + # height: 720 # torso_rgb: # _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera # name: depth @@ -29,9 +36,15 @@ cameras: # # host: localhost # port: 50065 # image_type: rgb + # fps: 30 + # width: 1280 + # height: 720 # torso_depth: # _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera # name: depth # host: 172.17.135.207 # port: 50065 # image_type: depth + # fps: 30 + # width: 1280 + # height: 720 diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index ad73eef4..3eac60ea 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -191,7 +191,7 @@ def teleoperate( @safe_disconnect def record( robot: Robot, - root: str, + root: Path, repo_id: str, single_task: str, pretrained_policy_name_or_path: str | None = None, @@ -204,6 +204,7 @@ def record( video: bool = True, run_compute_stats: bool = True, push_to_hub: bool = True, + tags: list[str] | None = None, num_image_writer_processes: int = 0, num_image_writer_threads_per_camera: int = 4, display_cameras: bool = True, @@ -331,7 +332,7 @@ def record( dataset.consolidate(run_compute_stats) if push_to_hub: - dataset.push_to_hub() + dataset.push_to_hub(tags=tags) log_say("Exiting", play_sounds) return dataset @@ -427,7 +428,7 @@ if __name__ == "__main__": parser_record.add_argument( "--root", type=Path, - default="data", + default=None, help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').", ) parser_record.add_argument( @@ -436,6 +437,12 @@ if __name__ == "__main__": default="lerobot/test", help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", ) + parser_record.add_argument( + "--resume", + type=int, + default=0, + help="Resume recording on an existing dataset.", + ) parser_record.add_argument( "--warmup-time-s", type=int, @@ -494,12 +501,6 @@ if __name__ == "__main__": "Not enough threads might cause low camera fps." ), ) - parser_record.add_argument( - "--force-override", - type=int, - default=0, - help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.", - ) parser_record.add_argument( "-p", "--pretrained-policy-name-or-path", @@ -523,7 +524,7 @@ if __name__ == "__main__": parser_replay.add_argument( "--root", type=Path, - default="data", + default=None, help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').", ) parser_replay.add_argument(