This commit is contained in:
Simon Alibert 2024-11-27 11:11:54 +01:00
parent 31429e82d0
commit 6366c7f46e
5 changed files with 92 additions and 40 deletions

View File

@ -2,8 +2,9 @@
Wrapper for Reachy2 camera from sdk Wrapper for Reachy2 camera from sdk
""" """
from dataclasses import dataclass from dataclasses import dataclass, replace
import cv2
import numpy as np import numpy as np
from reachy2_sdk.media.camera import CameraView from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager from reachy2_sdk.media.camera_manager import CameraManager
@ -18,6 +19,14 @@ class ReachyCameraConfig:
rotation: int | None = None rotation: int | None = None
mock: bool = False mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
class ReachyCamera: class ReachyCamera:
def __init__( def __init__(
@ -29,8 +38,18 @@ class ReachyCamera:
config: ReachyCameraConfig | None = None, config: ReachyCameraConfig | None = None,
**kwargs, **kwargs,
): ):
if config is None:
config = ReachyCameraConfig()
# Overwrite config arguments using kwargs
config = replace(config, **kwargs)
self.host = host self.host = host
self.port = port self.port = port
self.width = config.width
self.height = config.height
self.channels = config.channels
self.fps = config.fps
self.image_type = image_type self.image_type = image_type
self.name = name self.name = name
self.config = config self.config = config
@ -48,21 +67,24 @@ class ReachyCamera:
if not self.is_connected: if not self.is_connected:
self.connect() self.connect()
frame = None
if self.name == "teleop" and hasattr(self.cam_manager, "teleop"): if self.name == "teleop" and hasattr(self.cam_manager, "teleop"):
if self.image_type == "left": if self.image_type == "left":
return self.cam_manager.teleop.get_frame(CameraView.LEFT) frame = self.cam_manager.teleop.get_frame(CameraView.LEFT)
# return self.cam_manager.teleop.get_compressed_frame(CameraView.LEFT)
elif self.image_type == "right": elif self.image_type == "right":
return self.cam_manager.teleop.get_frame(CameraView.RIGHT) frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT)
# return self.cam_manager.teleop.get_compressed_frame(CameraView.RIGHT)
else:
return None
elif self.name == "depth" and hasattr(self.cam_manager, "depth"): elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.image_type == "depth": if self.image_type == "depth":
return self.cam_manager.depth.get_depth_frame() frame = self.cam_manager.depth.get_depth_frame()
elif self.image_type == "rgb": elif self.image_type == "rgb":
return self.cam_manager.depth.get_frame() frame = self.cam_manager.depth.get_frame()
# return self.cam_manager.depth.get_compressed_frame()
else: if frame is None:
return None return None
return None
if frame is not None and self.config.color_mode == "rgb":
img, timestamp = frame
frame = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB), timestamp)
return frame

View File

@ -46,7 +46,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
log_dt("dt", dt_s) log_dt("dt", dt_s)
# TODO(aliberts): move robot-specific logs logic in robot.print_logs() # TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.startswith(("stretch", "Reachy")): if not robot.robot_type.lower().startswith(("stretch", "reachy")):
for name in robot.leader_arms: for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s" key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs: if key in robot.logs:

View File

@ -18,10 +18,11 @@ import time
from copy import copy from copy import copy
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
import numpy as np
import torch import torch
from reachy2_sdk import ReachySDK from reachy2_sdk import ReachySDK
from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.reachy2 import ReachyCamera
REACHY_MOTORS = [ REACHY_MOTORS = [
"neck_yaw.pos", "neck_yaw.pos",
@ -52,8 +53,9 @@ REACHY_MOTORS = [
@dataclass @dataclass
class ReachyRobotConfig: class ReachyRobotConfig:
robot_type: str | None = "reachy2" robot_type: str | None = "reachy2"
cameras: dict[str, Camera] = field(default_factory=lambda: {}) cameras: dict[str, ReachyCamera] = field(default_factory=lambda: {})
ip_address: str | None = "172.17.135.207" ip_address: str | None = "172.17.135.207"
# ip_address: str | None = "192.168.0.197"
# ip_address: str | None = "localhost" # ip_address: str | None = "localhost"
@ -74,10 +76,8 @@ class ReachyRobot:
self.is_connected = False self.is_connected = False
self.teleop = None self.teleop = None
self.logs = {} self.logs = {}
self.reachy: ReachySDK = ReachySDK(host=config.ip_address) self.reachy = None
self.reachy.turn_on() self.mobile_base_available = False
self.is_connected = True # at init Reachy2 is in fact connected...
self.mobile_base_available = self.reachy.mobile_base is not None
self.state_keys = None self.state_keys = None
self.action_keys = None self.action_keys = None
@ -96,16 +96,19 @@ class ReachyRobot:
@property @property
def motor_features(self) -> dict: def motor_features(self) -> dict:
motors = REACHY_MOTORS
# if self.mobile_base_available:
# motors += REACHY_MOBILE_BASE
return { return {
"action": { "action": {
"dtype": "float32", "dtype": "float32",
"shape": (len(REACHY_MOTORS),), "shape": (len(motors),),
"names": REACHY_MOTORS, "names": motors,
}, },
"observation.state": { "observation.state": {
"dtype": "float32", "dtype": "float32",
"shape": (len(REACHY_MOTORS),), "shape": (len(motors),),
"names": REACHY_MOTORS, "names": motors,
}, },
} }
@ -114,14 +117,16 @@ class ReachyRobot:
return {**self.motor_features, **self.camera_features} return {**self.motor_features, **self.camera_features}
def connect(self) -> None: def connect(self) -> None:
self.reachy = ReachySDK(host=self.config.ip_address)
print("Connecting to Reachy") print("Connecting to Reachy")
self.reachy.is_connected = self.reachy.connect() self.reachy.connect()
self.is_connected = self.reachy.is_connected
if not self.is_connected: if not self.is_connected:
print( print(
f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists." f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists."
) )
raise ConnectionError() raise ConnectionError()
self.reachy.turn_on() # self.reachy.turn_on()
print(self.cameras) print(self.cameras)
if self.cameras is not None: if self.cameras is not None:
for name in self.cameras: for name in self.cameras:
@ -133,6 +138,8 @@ class ReachyRobot:
print("Could not connect to the cameras, check that all cameras are plugged-in.") print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError() raise ConnectionError()
self.mobile_base_available = self.reachy.mobile_base is not None
def run_calibration(self): def run_calibration(self):
pass pass
@ -169,8 +176,14 @@ class ReachyRobot:
action["mobile_base_x.vel"] = last_cmd_vel["x"] action["mobile_base_x.vel"] = last_cmd_vel["x"]
action["mobile_base_y.vel"] = last_cmd_vel["y"] action["mobile_base_y.vel"] = last_cmd_vel["y"]
action["mobile_base_theta.vel"] = last_cmd_vel["theta"] action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
else:
action["mobile_base_x.vel"] = 0
action["mobile_base_y.vel"] = 0
action["mobile_base_theta.vel"] = 0
action = torch.as_tensor(list(action.values())) dtype = self.motor_features["action"]["dtype"]
action = np.array(list(action.values()), dtype=dtype)
# action = torch.as_tensor(list(action.values()))
obs_dict = self.capture_observation() obs_dict = self.capture_observation()
action_dict = {} action_dict = {}
@ -224,7 +237,9 @@ class ReachyRobot:
if self.state_keys is None: if self.state_keys is None:
self.state_keys = list(state) self.state_keys = list(state)
state = torch.as_tensor(list(state.values())) dtype = self.motor_features["observation.state"]["dtype"]
state = np.array(list(state.values()), dtype=dtype)
# state = torch.as_tensor(list(state.values()))
# Capture images from cameras # Capture images from cameras
images = {} images = {}
@ -233,6 +248,7 @@ class ReachyRobot:
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking? images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
# print(f'name: {name} img: {images[name]}') # print(f'name: {name} img: {images[name]}')
if images[name] is not None: if images[name] is not None:
# images[name] = copy(images[name][0]) # seems like I need to copy?
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy? images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
@ -295,7 +311,7 @@ class ReachyRobot:
print("Disconnecting") print("Disconnecting")
self.is_connected = False self.is_connected = False
print("Turn off") print("Turn off")
self.reachy.turn_off_smoothly() # self.reachy.turn_off_smoothly()
# self.reachy.turn_off() # self.reachy.turn_off()
print("\t turn off done") print("\t turn off done")
self.reachy.disconnect() self.reachy.disconnect()

View File

@ -12,9 +12,13 @@ cameras:
head_left: head_left:
_target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
name: teleop name: teleop
host: 172.17.135.207 host: 172.17.134.85
# host: 192.168.0.197
# host: localhost # host: localhost
port: 50065 port: 50065
fps: 30
width: 960
height: 720
image_type: left image_type: left
# head_right: # head_right:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera # _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
@ -22,6 +26,9 @@ cameras:
# host: 172.17.135.207 # host: 172.17.135.207
# port: 50065 # port: 50065
# image_type: right # image_type: right
# fps: 30
# width: 960
# height: 720
# torso_rgb: # torso_rgb:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera # _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: depth # name: depth
@ -29,9 +36,15 @@ cameras:
# # host: localhost # # host: localhost
# port: 50065 # port: 50065
# image_type: rgb # image_type: rgb
# fps: 30
# width: 1280
# height: 720
# torso_depth: # torso_depth:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera # _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: depth # name: depth
# host: 172.17.135.207 # host: 172.17.135.207
# port: 50065 # port: 50065
# image_type: depth # image_type: depth
# fps: 30
# width: 1280
# height: 720

View File

@ -191,7 +191,7 @@ def teleoperate(
@safe_disconnect @safe_disconnect
def record( def record(
robot: Robot, robot: Robot,
root: str, root: Path,
repo_id: str, repo_id: str,
single_task: str, single_task: str,
pretrained_policy_name_or_path: str | None = None, pretrained_policy_name_or_path: str | None = None,
@ -204,6 +204,7 @@ def record(
video: bool = True, video: bool = True,
run_compute_stats: bool = True, run_compute_stats: bool = True,
push_to_hub: bool = True, push_to_hub: bool = True,
tags: list[str] | None = None,
num_image_writer_processes: int = 0, num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4, num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True, display_cameras: bool = True,
@ -331,7 +332,7 @@ def record(
dataset.consolidate(run_compute_stats) dataset.consolidate(run_compute_stats)
if push_to_hub: if push_to_hub:
dataset.push_to_hub() dataset.push_to_hub(tags=tags)
log_say("Exiting", play_sounds) log_say("Exiting", play_sounds)
return dataset return dataset
@ -427,7 +428,7 @@ if __name__ == "__main__":
parser_record.add_argument( parser_record.add_argument(
"--root", "--root",
type=Path, type=Path,
default="data", default=None,
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').", help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
) )
parser_record.add_argument( parser_record.add_argument(
@ -436,6 +437,12 @@ if __name__ == "__main__":
default="lerobot/test", default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
) )
parser_record.add_argument(
"--resume",
type=int,
default=0,
help="Resume recording on an existing dataset.",
)
parser_record.add_argument( parser_record.add_argument(
"--warmup-time-s", "--warmup-time-s",
type=int, type=int,
@ -494,12 +501,6 @@ if __name__ == "__main__":
"Not enough threads might cause low camera fps." "Not enough threads might cause low camera fps."
), ),
) )
parser_record.add_argument(
"--force-override",
type=int,
default=0,
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
)
parser_record.add_argument( parser_record.add_argument(
"-p", "-p",
"--pretrained-policy-name-or-path", "--pretrained-policy-name-or-path",
@ -523,7 +524,7 @@ if __name__ == "__main__":
parser_replay.add_argument( parser_replay.add_argument(
"--root", "--root",
type=Path, type=Path,
default="data", default=None,
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').", help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
) )
parser_replay.add_argument( parser_replay.add_argument(