Add features + formatting
This commit is contained in:
parent
c79d7ed146
commit
31429e82d0
|
@ -2,34 +2,15 @@
|
||||||
Wrapper for Reachy2 camera from sdk
|
Wrapper for Reachy2 camera from sdk
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
from dataclasses import dataclass
|
||||||
import concurrent.futures
|
|
||||||
import math
|
|
||||||
import platform
|
|
||||||
import shutil
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, replace
|
|
||||||
from pathlib import Path
|
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from lerobot.common.robot_devices.utils import (
|
|
||||||
RobotDeviceAlreadyConnectedError,
|
|
||||||
RobotDeviceNotConnectedError,
|
|
||||||
busy_wait,
|
|
||||||
)
|
|
||||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
|
||||||
|
|
||||||
from reachy2_sdk.media.camera import CameraView
|
from reachy2_sdk.media.camera import CameraView
|
||||||
from reachy2_sdk.media.camera_manager import CameraManager
|
from reachy2_sdk.media.camera_manager import CameraManager
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReachyCameraConfig:
|
class ReachyCameraConfig:
|
||||||
|
|
||||||
fps: int | None = None
|
fps: int | None = None
|
||||||
width: int | None = None
|
width: int | None = None
|
||||||
height: int | None = None
|
height: int | None = None
|
||||||
|
@ -46,7 +27,7 @@ class ReachyCamera:
|
||||||
name: str,
|
name: str,
|
||||||
image_type: str,
|
image_type: str,
|
||||||
config: ReachyCameraConfig | None = None,
|
config: ReachyCameraConfig | None = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
@ -64,7 +45,6 @@ class ReachyCamera:
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
def read(self) -> np.ndarray:
|
def read(self) -> np.ndarray:
|
||||||
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
self.connect()
|
self.connect()
|
||||||
|
|
||||||
|
@ -78,7 +58,6 @@ class ReachyCamera:
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
|
elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||||
|
|
||||||
if self.image_type == "depth":
|
if self.image_type == "depth":
|
||||||
return self.cam_manager.depth.get_depth_frame()
|
return self.cam_manager.depth.get_depth_frame()
|
||||||
elif self.image_type == "rgb":
|
elif self.image_type == "rgb":
|
||||||
|
|
|
@ -18,17 +18,40 @@ import time
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from reachy2_sdk import ReachySDK
|
from reachy2_sdk import ReachySDK
|
||||||
|
|
||||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||||
|
|
||||||
|
REACHY_MOTORS = [
|
||||||
|
"neck_yaw.pos",
|
||||||
|
"neck_pitch.pos",
|
||||||
|
"neck_roll.pos",
|
||||||
|
"r_shoulder_pitch.pos",
|
||||||
|
"r_shoulder_roll.pos",
|
||||||
|
"r_elbow_yaw.pos",
|
||||||
|
"r_elbow_pitch.pos",
|
||||||
|
"r_wrist_roll.pos",
|
||||||
|
"r_wrist_pitch.pos",
|
||||||
|
"r_wrist_yaw.pos",
|
||||||
|
"r_gripper.pos",
|
||||||
|
"l_shoulder_pitch.pos",
|
||||||
|
"l_shoulder_roll.pos",
|
||||||
|
"l_elbow_yaw.pos",
|
||||||
|
"l_elbow_pitch.pos",
|
||||||
|
"l_wrist_roll.pos",
|
||||||
|
"l_wrist_pitch.pos",
|
||||||
|
"l_wrist_yaw.pos",
|
||||||
|
"l_gripper.pos",
|
||||||
|
"mobile_base.vx",
|
||||||
|
"mobile_base.vy",
|
||||||
|
"mobile_base.vtheta",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReachyRobotConfig:
|
class ReachyRobotConfig:
|
||||||
robot_type: str | None = "Reachy2"
|
robot_type: str | None = "reachy2"
|
||||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||||
ip_address: str | None = "172.17.135.207"
|
ip_address: str | None = "172.17.135.207"
|
||||||
# ip_address: str | None = "localhost"
|
# ip_address: str | None = "localhost"
|
||||||
|
@ -59,6 +82,37 @@ class ReachyRobot:
|
||||||
self.state_keys = None
|
self.state_keys = None
|
||||||
self.action_keys = None
|
self.action_keys = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_features(self) -> dict:
|
||||||
|
cam_ft = {}
|
||||||
|
for cam_key, cam in self.cameras.items():
|
||||||
|
key = f"observation.images.{cam_key}"
|
||||||
|
cam_ft[key] = {
|
||||||
|
"shape": (cam.height, cam.width, cam.channels),
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": None,
|
||||||
|
}
|
||||||
|
return cam_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def motor_features(self) -> dict:
|
||||||
|
return {
|
||||||
|
"action": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(REACHY_MOTORS),),
|
||||||
|
"names": REACHY_MOTORS,
|
||||||
|
},
|
||||||
|
"observation.state": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(REACHY_MOTORS),),
|
||||||
|
"names": REACHY_MOTORS,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def features(self):
|
||||||
|
return {**self.motor_features, **self.camera_features}
|
||||||
|
|
||||||
def connect(self) -> None:
|
def connect(self) -> None:
|
||||||
print("Connecting to Reachy")
|
print("Connecting to Reachy")
|
||||||
self.reachy.is_connected = self.reachy.connect()
|
self.reachy.is_connected = self.reachy.connect()
|
||||||
|
@ -73,14 +127,10 @@ class ReachyRobot:
|
||||||
for name in self.cameras:
|
for name in self.cameras:
|
||||||
print(f"Connecting camera: {name}")
|
print(f"Connecting camera: {name}")
|
||||||
self.cameras[name].connect()
|
self.cameras[name].connect()
|
||||||
self.is_connected = (
|
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||||
self.is_connected and self.cameras[name].is_connected
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
print(
|
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||||
"Could not connect to the cameras, check that all cameras are plugged-in."
|
|
||||||
)
|
|
||||||
raise ConnectionError()
|
raise ConnectionError()
|
||||||
|
|
||||||
def run_calibration(self):
|
def run_calibration(self):
|
||||||
|
@ -179,18 +229,12 @@ class ReachyRobot:
|
||||||
# Capture images from cameras
|
# Capture images from cameras
|
||||||
images = {}
|
images = {}
|
||||||
for name in self.cameras:
|
for name in self.cameras:
|
||||||
before_camread_t = time.perf_counter()
|
# before_camread_t = time.perf_counter()
|
||||||
images[name] = self.cameras[
|
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
|
||||||
name
|
|
||||||
].read() # Reachy cameras read() is not blocking?
|
|
||||||
# print(f'name: {name} img: {images[name]}')
|
# print(f'name: {name} img: {images[name]}')
|
||||||
if images[name] is not None:
|
if images[name] is not None:
|
||||||
images[name] = torch.from_numpy(
|
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
|
||||||
copy(images[name][0])
|
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
|
||||||
) # seems like I need to copy?
|
|
||||||
self.logs[f"read_camera_{name}_dt_s"] = images[name][
|
|
||||||
1
|
|
||||||
] # full timestamp, TODO dt
|
|
||||||
|
|
||||||
# Populate output dictionnaries
|
# Populate output dictionnaries
|
||||||
obs_dict = {}
|
obs_dict = {}
|
||||||
|
|
|
@ -6,8 +6,7 @@
|
||||||
|
|
||||||
|
|
||||||
_target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobot
|
_target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobot
|
||||||
|
robot_type: reachy2
|
||||||
robot_type: Reachy2
|
|
||||||
|
|
||||||
cameras:
|
cameras:
|
||||||
head_left:
|
head_left:
|
||||||
|
|
|
@ -1,30 +1,10 @@
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
# from safetensors.torch import load_file, save_file
|
# from safetensors.torch import load_file, save_file
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
from lerobot.common.datasets.populate_dataset import (create_lerobot_dataset,
|
|
||||||
delete_current_episode,
|
|
||||||
init_dataset,
|
|
||||||
save_current_episode)
|
|
||||||
from lerobot.common.robot_devices.control_utils import (
|
|
||||||
control_loop, has_method, init_keyboard_listener, init_policy,
|
|
||||||
log_control_info, record_episode, reset_environment,
|
|
||||||
sanity_check_dataset_name, stop_recording, warmup_record)
|
|
||||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
from lerobot.common.utils.utils import init_hydra_config, init_logging
|
||||||
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
|
||||||
from lerobot.common.utils.utils import (init_hydra_config, init_logging,
|
|
||||||
log_say, none_or_int)
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
import time
|
|
||||||
import cv2
|
|
||||||
if __name__ == '__main__':
|
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
control_mode = "test"
|
control_mode = "test"
|
||||||
|
|
Loading…
Reference in New Issue