Add features + formatting
This commit is contained in:
parent
c79d7ed146
commit
31429e82d0
|
@ -2,34 +2,15 @@
|
|||
Wrapper for Reachy2 camera from sdk
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import math
|
||||
import platform
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
busy_wait,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReachyCameraConfig:
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
|
@ -46,7 +27,7 @@ class ReachyCamera:
|
|||
name: str,
|
||||
image_type: str,
|
||||
config: ReachyCameraConfig | None = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
@ -64,7 +45,6 @@ class ReachyCamera:
|
|||
self.is_connected = True
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
|
@ -78,7 +58,6 @@ class ReachyCamera:
|
|||
else:
|
||||
return None
|
||||
elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||
|
||||
if self.image_type == "depth":
|
||||
return self.cam_manager.depth.get_depth_frame()
|
||||
elif self.image_type == "rgb":
|
||||
|
|
|
@ -18,17 +18,40 @@ import time
|
|||
from copy import copy
|
||||
from dataclasses import dataclass, field, replace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from reachy2_sdk import ReachySDK
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
|
||||
REACHY_MOTORS = [
|
||||
"neck_yaw.pos",
|
||||
"neck_pitch.pos",
|
||||
"neck_roll.pos",
|
||||
"r_shoulder_pitch.pos",
|
||||
"r_shoulder_roll.pos",
|
||||
"r_elbow_yaw.pos",
|
||||
"r_elbow_pitch.pos",
|
||||
"r_wrist_roll.pos",
|
||||
"r_wrist_pitch.pos",
|
||||
"r_wrist_yaw.pos",
|
||||
"r_gripper.pos",
|
||||
"l_shoulder_pitch.pos",
|
||||
"l_shoulder_roll.pos",
|
||||
"l_elbow_yaw.pos",
|
||||
"l_elbow_pitch.pos",
|
||||
"l_wrist_roll.pos",
|
||||
"l_wrist_pitch.pos",
|
||||
"l_wrist_yaw.pos",
|
||||
"l_gripper.pos",
|
||||
"mobile_base.vx",
|
||||
"mobile_base.vy",
|
||||
"mobile_base.vtheta",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReachyRobotConfig:
|
||||
robot_type: str | None = "Reachy2"
|
||||
robot_type: str | None = "reachy2"
|
||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||
ip_address: str | None = "172.17.135.207"
|
||||
# ip_address: str | None = "localhost"
|
||||
|
@ -59,6 +82,37 @@ class ReachyRobot:
|
|||
self.state_keys = None
|
||||
self.action_keys = None
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(REACHY_MOTORS),),
|
||||
"names": REACHY_MOTORS,
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(REACHY_MOTORS),),
|
||||
"names": REACHY_MOTORS,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
return {**self.motor_features, **self.camera_features}
|
||||
|
||||
def connect(self) -> None:
|
||||
print("Connecting to Reachy")
|
||||
self.reachy.is_connected = self.reachy.connect()
|
||||
|
@ -73,14 +127,10 @@ class ReachyRobot:
|
|||
for name in self.cameras:
|
||||
print(f"Connecting camera: {name}")
|
||||
self.cameras[name].connect()
|
||||
self.is_connected = (
|
||||
self.is_connected and self.cameras[name].is_connected
|
||||
)
|
||||
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||
|
||||
if not self.is_connected:
|
||||
print(
|
||||
"Could not connect to the cameras, check that all cameras are plugged-in."
|
||||
)
|
||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
raise ConnectionError()
|
||||
|
||||
def run_calibration(self):
|
||||
|
@ -119,7 +169,7 @@ class ReachyRobot:
|
|||
action["mobile_base_x.vel"] = last_cmd_vel["x"]
|
||||
action["mobile_base_y.vel"] = last_cmd_vel["y"]
|
||||
action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
|
||||
|
||||
|
||||
action = torch.as_tensor(list(action.values()))
|
||||
|
||||
obs_dict = self.capture_observation()
|
||||
|
@ -179,18 +229,12 @@ class ReachyRobot:
|
|||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[
|
||||
name
|
||||
].read() # Reachy cameras read() is not blocking?
|
||||
# before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
|
||||
# print(f'name: {name} img: {images[name]}')
|
||||
if images[name] is not None:
|
||||
images[name] = torch.from_numpy(
|
||||
copy(images[name][0])
|
||||
) # seems like I need to copy?
|
||||
self.logs[f"read_camera_{name}_dt_s"] = images[name][
|
||||
1
|
||||
] # full timestamp, TODO dt
|
||||
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
|
||||
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
|
||||
|
||||
# Populate output dictionnaries
|
||||
obs_dict = {}
|
||||
|
|
|
@ -6,8 +6,7 @@
|
|||
|
||||
|
||||
_target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobot
|
||||
|
||||
robot_type: Reachy2
|
||||
robot_type: reachy2
|
||||
|
||||
cameras:
|
||||
head_left:
|
||||
|
|
|
@ -1,30 +1,10 @@
|
|||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.populate_dataset import (create_lerobot_dataset,
|
||||
delete_current_episode,
|
||||
init_dataset,
|
||||
save_current_episode)
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop, has_method, init_keyboard_listener, init_policy,
|
||||
log_control_info, record_episode, reset_environment,
|
||||
sanity_check_dataset_name, stop_recording, warmup_record)
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
||||
from lerobot.common.utils.utils import (init_hydra_config, init_logging,
|
||||
log_say, none_or_int)
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging
|
||||
|
||||
|
||||
import time
|
||||
import cv2
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
control_mode = "test"
|
||||
|
|
Loading…
Reference in New Issue