WIP
This commit is contained in:
parent
31429e82d0
commit
6366c7f46e
|
@ -2,8 +2,9 @@
|
||||||
Wrapper for Reachy2 camera from sdk
|
Wrapper for Reachy2 camera from sdk
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, replace
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from reachy2_sdk.media.camera import CameraView
|
from reachy2_sdk.media.camera import CameraView
|
||||||
from reachy2_sdk.media.camera_manager import CameraManager
|
from reachy2_sdk.media.camera_manager import CameraManager
|
||||||
|
@ -18,6 +19,14 @@ class ReachyCameraConfig:
|
||||||
rotation: int | None = None
|
rotation: int | None = None
|
||||||
mock: bool = False
|
mock: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.color_mode not in ["rgb", "bgr"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.channels = 3
|
||||||
|
|
||||||
|
|
||||||
class ReachyCamera:
|
class ReachyCamera:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -29,8 +38,18 @@ class ReachyCamera:
|
||||||
config: ReachyCameraConfig | None = None,
|
config: ReachyCameraConfig | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if config is None:
|
||||||
|
config = ReachyCameraConfig()
|
||||||
|
|
||||||
|
# Overwrite config arguments using kwargs
|
||||||
|
config = replace(config, **kwargs)
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
self.width = config.width
|
||||||
|
self.height = config.height
|
||||||
|
self.channels = config.channels
|
||||||
|
self.fps = config.fps
|
||||||
self.image_type = image_type
|
self.image_type = image_type
|
||||||
self.name = name
|
self.name = name
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -48,21 +67,24 @@ class ReachyCamera:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
self.connect()
|
self.connect()
|
||||||
|
|
||||||
|
frame = None
|
||||||
|
|
||||||
if self.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
if self.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||||
if self.image_type == "left":
|
if self.image_type == "left":
|
||||||
return self.cam_manager.teleop.get_frame(CameraView.LEFT)
|
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT)
|
||||||
# return self.cam_manager.teleop.get_compressed_frame(CameraView.LEFT)
|
|
||||||
elif self.image_type == "right":
|
elif self.image_type == "right":
|
||||||
return self.cam_manager.teleop.get_frame(CameraView.RIGHT)
|
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT)
|
||||||
# return self.cam_manager.teleop.get_compressed_frame(CameraView.RIGHT)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
|
elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||||
if self.image_type == "depth":
|
if self.image_type == "depth":
|
||||||
return self.cam_manager.depth.get_depth_frame()
|
frame = self.cam_manager.depth.get_depth_frame()
|
||||||
elif self.image_type == "rgb":
|
elif self.image_type == "rgb":
|
||||||
return self.cam_manager.depth.get_frame()
|
frame = self.cam_manager.depth.get_frame()
|
||||||
# return self.cam_manager.depth.get_compressed_frame()
|
|
||||||
else:
|
if frame is None:
|
||||||
return None
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if frame is not None and self.config.color_mode == "rgb":
|
||||||
|
img, timestamp = frame
|
||||||
|
frame = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB), timestamp)
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
|
@ -46,7 +46,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
|
||||||
log_dt("dt", dt_s)
|
log_dt("dt", dt_s)
|
||||||
|
|
||||||
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
|
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
|
||||||
if not robot.robot_type.startswith(("stretch", "Reachy")):
|
if not robot.robot_type.lower().startswith(("stretch", "reachy")):
|
||||||
for name in robot.leader_arms:
|
for name in robot.leader_arms:
|
||||||
key = f"read_leader_{name}_pos_dt_s"
|
key = f"read_leader_{name}_pos_dt_s"
|
||||||
if key in robot.logs:
|
if key in robot.logs:
|
||||||
|
|
|
@ -18,10 +18,11 @@ import time
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from reachy2_sdk import ReachySDK
|
from reachy2_sdk import ReachySDK
|
||||||
|
|
||||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
from lerobot.common.robot_devices.cameras.reachy2 import ReachyCamera
|
||||||
|
|
||||||
REACHY_MOTORS = [
|
REACHY_MOTORS = [
|
||||||
"neck_yaw.pos",
|
"neck_yaw.pos",
|
||||||
|
@ -52,8 +53,9 @@ REACHY_MOTORS = [
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReachyRobotConfig:
|
class ReachyRobotConfig:
|
||||||
robot_type: str | None = "reachy2"
|
robot_type: str | None = "reachy2"
|
||||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
cameras: dict[str, ReachyCamera] = field(default_factory=lambda: {})
|
||||||
ip_address: str | None = "172.17.135.207"
|
ip_address: str | None = "172.17.135.207"
|
||||||
|
# ip_address: str | None = "192.168.0.197"
|
||||||
# ip_address: str | None = "localhost"
|
# ip_address: str | None = "localhost"
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,10 +76,8 @@ class ReachyRobot:
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
self.teleop = None
|
self.teleop = None
|
||||||
self.logs = {}
|
self.logs = {}
|
||||||
self.reachy: ReachySDK = ReachySDK(host=config.ip_address)
|
self.reachy = None
|
||||||
self.reachy.turn_on()
|
self.mobile_base_available = False
|
||||||
self.is_connected = True # at init Reachy2 is in fact connected...
|
|
||||||
self.mobile_base_available = self.reachy.mobile_base is not None
|
|
||||||
|
|
||||||
self.state_keys = None
|
self.state_keys = None
|
||||||
self.action_keys = None
|
self.action_keys = None
|
||||||
|
@ -96,16 +96,19 @@ class ReachyRobot:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def motor_features(self) -> dict:
|
def motor_features(self) -> dict:
|
||||||
|
motors = REACHY_MOTORS
|
||||||
|
# if self.mobile_base_available:
|
||||||
|
# motors += REACHY_MOBILE_BASE
|
||||||
return {
|
return {
|
||||||
"action": {
|
"action": {
|
||||||
"dtype": "float32",
|
"dtype": "float32",
|
||||||
"shape": (len(REACHY_MOTORS),),
|
"shape": (len(motors),),
|
||||||
"names": REACHY_MOTORS,
|
"names": motors,
|
||||||
},
|
},
|
||||||
"observation.state": {
|
"observation.state": {
|
||||||
"dtype": "float32",
|
"dtype": "float32",
|
||||||
"shape": (len(REACHY_MOTORS),),
|
"shape": (len(motors),),
|
||||||
"names": REACHY_MOTORS,
|
"names": motors,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,14 +117,16 @@ class ReachyRobot:
|
||||||
return {**self.motor_features, **self.camera_features}
|
return {**self.motor_features, **self.camera_features}
|
||||||
|
|
||||||
def connect(self) -> None:
|
def connect(self) -> None:
|
||||||
|
self.reachy = ReachySDK(host=self.config.ip_address)
|
||||||
print("Connecting to Reachy")
|
print("Connecting to Reachy")
|
||||||
self.reachy.is_connected = self.reachy.connect()
|
self.reachy.connect()
|
||||||
|
self.is_connected = self.reachy.is_connected
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
print(
|
print(
|
||||||
f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists."
|
f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists."
|
||||||
)
|
)
|
||||||
raise ConnectionError()
|
raise ConnectionError()
|
||||||
self.reachy.turn_on()
|
# self.reachy.turn_on()
|
||||||
print(self.cameras)
|
print(self.cameras)
|
||||||
if self.cameras is not None:
|
if self.cameras is not None:
|
||||||
for name in self.cameras:
|
for name in self.cameras:
|
||||||
|
@ -133,6 +138,8 @@ class ReachyRobot:
|
||||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||||
raise ConnectionError()
|
raise ConnectionError()
|
||||||
|
|
||||||
|
self.mobile_base_available = self.reachy.mobile_base is not None
|
||||||
|
|
||||||
def run_calibration(self):
|
def run_calibration(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -169,8 +176,14 @@ class ReachyRobot:
|
||||||
action["mobile_base_x.vel"] = last_cmd_vel["x"]
|
action["mobile_base_x.vel"] = last_cmd_vel["x"]
|
||||||
action["mobile_base_y.vel"] = last_cmd_vel["y"]
|
action["mobile_base_y.vel"] = last_cmd_vel["y"]
|
||||||
action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
|
action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
|
||||||
|
else:
|
||||||
|
action["mobile_base_x.vel"] = 0
|
||||||
|
action["mobile_base_y.vel"] = 0
|
||||||
|
action["mobile_base_theta.vel"] = 0
|
||||||
|
|
||||||
action = torch.as_tensor(list(action.values()))
|
dtype = self.motor_features["action"]["dtype"]
|
||||||
|
action = np.array(list(action.values()), dtype=dtype)
|
||||||
|
# action = torch.as_tensor(list(action.values()))
|
||||||
|
|
||||||
obs_dict = self.capture_observation()
|
obs_dict = self.capture_observation()
|
||||||
action_dict = {}
|
action_dict = {}
|
||||||
|
@ -224,7 +237,9 @@ class ReachyRobot:
|
||||||
if self.state_keys is None:
|
if self.state_keys is None:
|
||||||
self.state_keys = list(state)
|
self.state_keys = list(state)
|
||||||
|
|
||||||
state = torch.as_tensor(list(state.values()))
|
dtype = self.motor_features["observation.state"]["dtype"]
|
||||||
|
state = np.array(list(state.values()), dtype=dtype)
|
||||||
|
# state = torch.as_tensor(list(state.values()))
|
||||||
|
|
||||||
# Capture images from cameras
|
# Capture images from cameras
|
||||||
images = {}
|
images = {}
|
||||||
|
@ -233,6 +248,7 @@ class ReachyRobot:
|
||||||
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
|
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
|
||||||
# print(f'name: {name} img: {images[name]}')
|
# print(f'name: {name} img: {images[name]}')
|
||||||
if images[name] is not None:
|
if images[name] is not None:
|
||||||
|
# images[name] = copy(images[name][0]) # seems like I need to copy?
|
||||||
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
|
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
|
||||||
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
|
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
|
||||||
|
|
||||||
|
@ -295,7 +311,7 @@ class ReachyRobot:
|
||||||
print("Disconnecting")
|
print("Disconnecting")
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
print("Turn off")
|
print("Turn off")
|
||||||
self.reachy.turn_off_smoothly()
|
# self.reachy.turn_off_smoothly()
|
||||||
# self.reachy.turn_off()
|
# self.reachy.turn_off()
|
||||||
print("\t turn off done")
|
print("\t turn off done")
|
||||||
self.reachy.disconnect()
|
self.reachy.disconnect()
|
||||||
|
|
|
@ -12,9 +12,13 @@ cameras:
|
||||||
head_left:
|
head_left:
|
||||||
_target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
_target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||||
name: teleop
|
name: teleop
|
||||||
host: 172.17.135.207
|
host: 172.17.134.85
|
||||||
|
# host: 192.168.0.197
|
||||||
# host: localhost
|
# host: localhost
|
||||||
port: 50065
|
port: 50065
|
||||||
|
fps: 30
|
||||||
|
width: 960
|
||||||
|
height: 720
|
||||||
image_type: left
|
image_type: left
|
||||||
# head_right:
|
# head_right:
|
||||||
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||||
|
@ -22,6 +26,9 @@ cameras:
|
||||||
# host: 172.17.135.207
|
# host: 172.17.135.207
|
||||||
# port: 50065
|
# port: 50065
|
||||||
# image_type: right
|
# image_type: right
|
||||||
|
# fps: 30
|
||||||
|
# width: 960
|
||||||
|
# height: 720
|
||||||
# torso_rgb:
|
# torso_rgb:
|
||||||
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||||
# name: depth
|
# name: depth
|
||||||
|
@ -29,9 +36,15 @@ cameras:
|
||||||
# # host: localhost
|
# # host: localhost
|
||||||
# port: 50065
|
# port: 50065
|
||||||
# image_type: rgb
|
# image_type: rgb
|
||||||
|
# fps: 30
|
||||||
|
# width: 1280
|
||||||
|
# height: 720
|
||||||
# torso_depth:
|
# torso_depth:
|
||||||
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||||
# name: depth
|
# name: depth
|
||||||
# host: 172.17.135.207
|
# host: 172.17.135.207
|
||||||
# port: 50065
|
# port: 50065
|
||||||
# image_type: depth
|
# image_type: depth
|
||||||
|
# fps: 30
|
||||||
|
# width: 1280
|
||||||
|
# height: 720
|
||||||
|
|
|
@ -191,7 +191,7 @@ def teleoperate(
|
||||||
@safe_disconnect
|
@safe_disconnect
|
||||||
def record(
|
def record(
|
||||||
robot: Robot,
|
robot: Robot,
|
||||||
root: str,
|
root: Path,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
single_task: str,
|
single_task: str,
|
||||||
pretrained_policy_name_or_path: str | None = None,
|
pretrained_policy_name_or_path: str | None = None,
|
||||||
|
@ -204,6 +204,7 @@ def record(
|
||||||
video: bool = True,
|
video: bool = True,
|
||||||
run_compute_stats: bool = True,
|
run_compute_stats: bool = True,
|
||||||
push_to_hub: bool = True,
|
push_to_hub: bool = True,
|
||||||
|
tags: list[str] | None = None,
|
||||||
num_image_writer_processes: int = 0,
|
num_image_writer_processes: int = 0,
|
||||||
num_image_writer_threads_per_camera: int = 4,
|
num_image_writer_threads_per_camera: int = 4,
|
||||||
display_cameras: bool = True,
|
display_cameras: bool = True,
|
||||||
|
@ -331,7 +332,7 @@ def record(
|
||||||
dataset.consolidate(run_compute_stats)
|
dataset.consolidate(run_compute_stats)
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
dataset.push_to_hub()
|
dataset.push_to_hub(tags=tags)
|
||||||
|
|
||||||
log_say("Exiting", play_sounds)
|
log_say("Exiting", play_sounds)
|
||||||
return dataset
|
return dataset
|
||||||
|
@ -427,7 +428,7 @@ if __name__ == "__main__":
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--root",
|
"--root",
|
||||||
type=Path,
|
type=Path,
|
||||||
default="data",
|
default=None,
|
||||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
||||||
)
|
)
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
|
@ -436,6 +437,12 @@ if __name__ == "__main__":
|
||||||
default="lerobot/test",
|
default="lerobot/test",
|
||||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||||
)
|
)
|
||||||
|
parser_record.add_argument(
|
||||||
|
"--resume",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Resume recording on an existing dataset.",
|
||||||
|
)
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--warmup-time-s",
|
"--warmup-time-s",
|
||||||
type=int,
|
type=int,
|
||||||
|
@ -494,12 +501,6 @@ if __name__ == "__main__":
|
||||||
"Not enough threads might cause low camera fps."
|
"Not enough threads might cause low camera fps."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser_record.add_argument(
|
|
||||||
"--force-override",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
|
|
||||||
)
|
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"-p",
|
"-p",
|
||||||
"--pretrained-policy-name-or-path",
|
"--pretrained-policy-name-or-path",
|
||||||
|
@ -523,7 +524,7 @@ if __name__ == "__main__":
|
||||||
parser_replay.add_argument(
|
parser_replay.add_argument(
|
||||||
"--root",
|
"--root",
|
||||||
type=Path,
|
type=Path,
|
||||||
default="data",
|
default=None,
|
||||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
||||||
)
|
)
|
||||||
parser_replay.add_argument(
|
parser_replay.add_argument(
|
||||||
|
|
Loading…
Reference in New Issue