Add features + formatting

This commit is contained in:
Simon Alibert 2024-11-26 11:11:24 +01:00
parent c79d7ed146
commit 31429e82d0
4 changed files with 68 additions and 66 deletions

View File

@ -2,34 +2,15 @@
Wrapper for Reachy2 camera from sdk
"""
import argparse
import concurrent.futures
import math
import platform
import shutil
import threading
import time
from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
from dataclasses import dataclass
import numpy as np
from PIL import Image
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
busy_wait,
)
from lerobot.common.utils.utils import capture_timestamp_utc
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
@dataclass
class ReachyCameraConfig:
fps: int | None = None
width: int | None = None
height: int | None = None
@ -46,7 +27,7 @@ class ReachyCamera:
name: str,
image_type: str,
config: ReachyCameraConfig | None = None,
**kwargs
**kwargs,
):
self.host = host
self.port = port
@ -64,7 +45,6 @@ class ReachyCamera:
self.is_connected = True
def read(self) -> np.ndarray:
if not self.is_connected:
self.connect()
@ -78,7 +58,6 @@ class ReachyCamera:
else:
return None
elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.image_type == "depth":
return self.cam_manager.depth.get_depth_frame()
elif self.image_type == "rgb":

View File

@ -18,17 +18,40 @@ import time
from copy import copy
from dataclasses import dataclass, field, replace
import numpy as np
import torch
from reachy2_sdk import ReachySDK
from lerobot.common.robot_devices.cameras.utils import Camera
REACHY_MOTORS = [
"neck_yaw.pos",
"neck_pitch.pos",
"neck_roll.pos",
"r_shoulder_pitch.pos",
"r_shoulder_roll.pos",
"r_elbow_yaw.pos",
"r_elbow_pitch.pos",
"r_wrist_roll.pos",
"r_wrist_pitch.pos",
"r_wrist_yaw.pos",
"r_gripper.pos",
"l_shoulder_pitch.pos",
"l_shoulder_roll.pos",
"l_elbow_yaw.pos",
"l_elbow_pitch.pos",
"l_wrist_roll.pos",
"l_wrist_pitch.pos",
"l_wrist_yaw.pos",
"l_gripper.pos",
"mobile_base.vx",
"mobile_base.vy",
"mobile_base.vtheta",
]
@dataclass
class ReachyRobotConfig:
robot_type: str | None = "Reachy2"
robot_type: str | None = "reachy2"
cameras: dict[str, Camera] = field(default_factory=lambda: {})
ip_address: str | None = "172.17.135.207"
# ip_address: str | None = "localhost"
@ -59,6 +82,37 @@ class ReachyRobot:
self.state_keys = None
self.action_keys = None
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
return {
"action": {
"dtype": "float32",
"shape": (len(REACHY_MOTORS),),
"names": REACHY_MOTORS,
},
"observation.state": {
"dtype": "float32",
"shape": (len(REACHY_MOTORS),),
"names": REACHY_MOTORS,
},
}
@property
def features(self):
return {**self.motor_features, **self.camera_features}
def connect(self) -> None:
print("Connecting to Reachy")
self.reachy.is_connected = self.reachy.connect()
@ -73,14 +127,10 @@ class ReachyRobot:
for name in self.cameras:
print(f"Connecting camera: {name}")
self.cameras[name].connect()
self.is_connected = (
self.is_connected and self.cameras[name].is_connected
)
self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected:
print(
"Could not connect to the cameras, check that all cameras are plugged-in."
)
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
def run_calibration(self):
@ -119,7 +169,7 @@ class ReachyRobot:
action["mobile_base_x.vel"] = last_cmd_vel["x"]
action["mobile_base_y.vel"] = last_cmd_vel["y"]
action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
action = torch.as_tensor(list(action.values()))
obs_dict = self.capture_observation()
@ -179,18 +229,12 @@ class ReachyRobot:
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[
name
].read() # Reachy cameras read() is not blocking?
# before_camread_t = time.perf_counter()
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
# print(f'name: {name} img: {images[name]}')
if images[name] is not None:
images[name] = torch.from_numpy(
copy(images[name][0])
) # seems like I need to copy?
self.logs[f"read_camera_{name}_dt_s"] = images[name][
1
] # full timestamp, TODO dt
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
# Populate output dictionnaries
obs_dict = {}

View File

@ -6,8 +6,7 @@
_target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobot
robot_type: Reachy2
robot_type: reachy2
cameras:
head_left:

View File

@ -1,30 +1,10 @@
import argparse
import logging
import time
from pathlib import Path
from typing import List
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.populate_dataset import (create_lerobot_dataset,
delete_current_episode,
init_dataset,
save_current_episode)
from lerobot.common.robot_devices.control_utils import (
control_loop, has_method, init_keyboard_listener, init_policy,
log_control_info, record_episode, reset_environment,
sanity_check_dataset_name, stop_recording, warmup_record)
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import (init_hydra_config, init_logging,
log_say, none_or_int)
from lerobot.common.utils.utils import init_hydra_config, init_logging
import time
import cv2
if __name__ == '__main__':
if __name__ == "__main__":
init_logging()
control_mode = "test"