Add features + formatting

This commit is contained in:
Simon Alibert 2024-11-26 11:11:24 +01:00
parent c79d7ed146
commit 31429e82d0
4 changed files with 68 additions and 66 deletions

View File

@ -2,34 +2,15 @@
Wrapper for Reachy2 camera from sdk Wrapper for Reachy2 camera from sdk
""" """
import argparse from dataclasses import dataclass
import concurrent.futures
import math
import platform
import shutil
import threading
import time
from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
import numpy as np import numpy as np
from PIL import Image
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
busy_wait,
)
from lerobot.common.utils.utils import capture_timestamp_utc
from reachy2_sdk.media.camera import CameraView from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager from reachy2_sdk.media.camera_manager import CameraManager
@dataclass @dataclass
class ReachyCameraConfig: class ReachyCameraConfig:
fps: int | None = None fps: int | None = None
width: int | None = None width: int | None = None
height: int | None = None height: int | None = None
@ -46,7 +27,7 @@ class ReachyCamera:
name: str, name: str,
image_type: str, image_type: str,
config: ReachyCameraConfig | None = None, config: ReachyCameraConfig | None = None,
**kwargs **kwargs,
): ):
self.host = host self.host = host
self.port = port self.port = port
@ -64,7 +45,6 @@ class ReachyCamera:
self.is_connected = True self.is_connected = True
def read(self) -> np.ndarray: def read(self) -> np.ndarray:
if not self.is_connected: if not self.is_connected:
self.connect() self.connect()
@ -78,7 +58,6 @@ class ReachyCamera:
else: else:
return None return None
elif self.name == "depth" and hasattr(self.cam_manager, "depth"): elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.image_type == "depth": if self.image_type == "depth":
return self.cam_manager.depth.get_depth_frame() return self.cam_manager.depth.get_depth_frame()
elif self.image_type == "rgb": elif self.image_type == "rgb":

View File

@ -18,17 +18,40 @@ import time
from copy import copy from copy import copy
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
import numpy as np
import torch import torch
from reachy2_sdk import ReachySDK from reachy2_sdk import ReachySDK
from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import Camera
REACHY_MOTORS = [
"neck_yaw.pos",
"neck_pitch.pos",
"neck_roll.pos",
"r_shoulder_pitch.pos",
"r_shoulder_roll.pos",
"r_elbow_yaw.pos",
"r_elbow_pitch.pos",
"r_wrist_roll.pos",
"r_wrist_pitch.pos",
"r_wrist_yaw.pos",
"r_gripper.pos",
"l_shoulder_pitch.pos",
"l_shoulder_roll.pos",
"l_elbow_yaw.pos",
"l_elbow_pitch.pos",
"l_wrist_roll.pos",
"l_wrist_pitch.pos",
"l_wrist_yaw.pos",
"l_gripper.pos",
"mobile_base.vx",
"mobile_base.vy",
"mobile_base.vtheta",
]
@dataclass @dataclass
class ReachyRobotConfig: class ReachyRobotConfig:
robot_type: str | None = "Reachy2" robot_type: str | None = "reachy2"
cameras: dict[str, Camera] = field(default_factory=lambda: {}) cameras: dict[str, Camera] = field(default_factory=lambda: {})
ip_address: str | None = "172.17.135.207" ip_address: str | None = "172.17.135.207"
# ip_address: str | None = "localhost" # ip_address: str | None = "localhost"
@ -59,6 +82,37 @@ class ReachyRobot:
self.state_keys = None self.state_keys = None
self.action_keys = None self.action_keys = None
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
return {
"action": {
"dtype": "float32",
"shape": (len(REACHY_MOTORS),),
"names": REACHY_MOTORS,
},
"observation.state": {
"dtype": "float32",
"shape": (len(REACHY_MOTORS),),
"names": REACHY_MOTORS,
},
}
@property
def features(self):
return {**self.motor_features, **self.camera_features}
def connect(self) -> None: def connect(self) -> None:
print("Connecting to Reachy") print("Connecting to Reachy")
self.reachy.is_connected = self.reachy.connect() self.reachy.is_connected = self.reachy.connect()
@ -73,14 +127,10 @@ class ReachyRobot:
for name in self.cameras: for name in self.cameras:
print(f"Connecting camera: {name}") print(f"Connecting camera: {name}")
self.cameras[name].connect() self.cameras[name].connect()
self.is_connected = ( self.is_connected = self.is_connected and self.cameras[name].is_connected
self.is_connected and self.cameras[name].is_connected
)
if not self.is_connected: if not self.is_connected:
print( print("Could not connect to the cameras, check that all cameras are plugged-in.")
"Could not connect to the cameras, check that all cameras are plugged-in."
)
raise ConnectionError() raise ConnectionError()
def run_calibration(self): def run_calibration(self):
@ -179,18 +229,12 @@ class ReachyRobot:
# Capture images from cameras # Capture images from cameras
images = {} images = {}
for name in self.cameras: for name in self.cameras:
before_camread_t = time.perf_counter() # before_camread_t = time.perf_counter()
images[name] = self.cameras[ images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
name
].read() # Reachy cameras read() is not blocking?
# print(f'name: {name} img: {images[name]}') # print(f'name: {name} img: {images[name]}')
if images[name] is not None: if images[name] is not None:
images[name] = torch.from_numpy( images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
copy(images[name][0]) self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
) # seems like I need to copy?
self.logs[f"read_camera_{name}_dt_s"] = images[name][
1
] # full timestamp, TODO dt
# Populate output dictionnaries # Populate output dictionnaries
obs_dict = {} obs_dict = {}

View File

@ -6,8 +6,7 @@
_target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobot _target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobot
robot_type: reachy2
robot_type: Reachy2
cameras: cameras:
head_left: head_left:

View File

@ -1,30 +1,10 @@
import argparse
import logging
import time import time
from pathlib import Path
from typing import List
# from safetensors.torch import load_file, save_file # from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.populate_dataset import (create_lerobot_dataset,
delete_current_episode,
init_dataset,
save_current_episode)
from lerobot.common.robot_devices.control_utils import (
control_loop, has_method, init_keyboard_listener, init_policy,
log_control_info, record_episode, reset_environment,
sanity_check_dataset_name, stop_recording, warmup_record)
from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.utils.utils import init_hydra_config, init_logging
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import (init_hydra_config, init_logging,
log_say, none_or_int)
if __name__ == "__main__":
import time
import cv2
if __name__ == '__main__':
init_logging() init_logging()
control_mode = "test" control_mode = "test"