[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
472853a818
commit
fd8032a023
|
@ -31,20 +31,20 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
||||||
MAX_OPENCV_INDEX = 60
|
MAX_OPENCV_INDEX = 60
|
||||||
|
|
||||||
undistort = True
|
undistort = True
|
||||||
|
|
||||||
|
|
||||||
def undistort_image(image):
|
def undistort_image(image):
|
||||||
import cv2
|
import cv2
|
||||||
camera_matrix = np.array([
|
|
||||||
[289.11451, 0., 347.23664],
|
camera_matrix = np.array([[289.11451, 0.0, 347.23664], [0.0, 289.75319, 235.67429], [0.0, 0.0, 1.0]])
|
||||||
[0., 289.75319, 235.67429],
|
|
||||||
[0., 0., 1.]
|
|
||||||
])
|
|
||||||
|
|
||||||
dist_coeffs = np.array([-0.208848, 0.028006, -0.000705, -0.000820, 0.000000])
|
dist_coeffs = np.array([-0.208848, 0.028006, -0.000705, -0.000820, 0.000000])
|
||||||
|
|
||||||
undistorted_image = cv2.undistort(image, camera_matrix, dist_coeffs)
|
undistorted_image = cv2.undistort(image, camera_matrix, dist_coeffs)
|
||||||
|
|
||||||
return undistorted_image
|
return undistorted_image
|
||||||
|
|
||||||
|
|
||||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||||
cameras = []
|
cameras = []
|
||||||
if platform.system() == "Linux":
|
if platform.system() == "Linux":
|
||||||
|
|
|
@ -153,13 +153,16 @@ def predict_action(observation, policy, device, use_amp):
|
||||||
|
|
||||||
# return listener, events
|
# return listener, events
|
||||||
|
|
||||||
|
|
||||||
def init_keyboard_listener():
|
def init_keyboard_listener():
|
||||||
events = {}
|
events = {}
|
||||||
events["exit_early"] = False
|
events["exit_early"] = False
|
||||||
events["rerecord_episode"] = False
|
events["rerecord_episode"] = False
|
||||||
events["stop_recording"] = False
|
events["stop_recording"] = False
|
||||||
|
import threading
|
||||||
|
|
||||||
from sshkeyboard import listen_keyboard
|
from sshkeyboard import listen_keyboard
|
||||||
import threading
|
|
||||||
def on_press(key):
|
def on_press(key):
|
||||||
try:
|
try:
|
||||||
if key == "right":
|
if key == "right":
|
||||||
|
@ -179,7 +182,8 @@ def init_keyboard_listener():
|
||||||
listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
|
listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
|
||||||
listener.start()
|
listener.start()
|
||||||
|
|
||||||
return listener,events
|
return listener, events
|
||||||
|
|
||||||
|
|
||||||
def warmup_record(
|
def warmup_record(
|
||||||
robot,
|
robot,
|
||||||
|
@ -284,7 +288,7 @@ def control_loop(
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
if display_cameras:
|
if display_cameras:
|
||||||
# if display_cameras and not is_headless():
|
# if display_cameras and not is_headless():
|
||||||
image_keys = [key for key in observation if "image" in key]
|
image_keys = [key for key in observation if "image" in key]
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||||
|
@ -316,6 +320,7 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||||
teleoperate=True,
|
teleoperate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# def stop_recording(robot, listener, display_cameras):
|
# def stop_recording(robot, listener, display_cameras):
|
||||||
# robot.disconnect()
|
# robot.disconnect()
|
||||||
|
|
||||||
|
@ -326,16 +331,19 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||||
# if display_cameras:
|
# if display_cameras:
|
||||||
# cv2.destroyAllWindows()
|
# cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
def stop_recording(robot, listener, display_cameras):
|
def stop_recording(robot, listener, display_cameras):
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
|
|
||||||
from sshkeyboard import stop_listening
|
from sshkeyboard import stop_listening
|
||||||
|
|
||||||
if listener is not None:
|
if listener is not None:
|
||||||
stop_listening()
|
stop_listening()
|
||||||
|
|
||||||
if display_cameras:
|
if display_cameras:
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||||
_, dataset_name = repo_id.split("/")
|
_, dataset_name = repo_id.split("/")
|
||||||
# either repo_id doesnt start with "eval_" and there is no policy
|
# either repo_id doesnt start with "eval_" and there is no policy
|
||||||
|
|
|
@ -479,6 +479,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||||
|
|
||||||
mock: bool = False
|
mock: bool = False
|
||||||
|
|
||||||
|
|
||||||
@RobotConfig.register_subclass("roarm_m3")
|
@RobotConfig.register_subclass("roarm_m3")
|
||||||
@dataclass
|
@dataclass
|
||||||
class RoarmRobotConfig(RobotConfig):
|
class RoarmRobotConfig(RobotConfig):
|
||||||
|
@ -489,14 +490,14 @@ class RoarmRobotConfig(RobotConfig):
|
||||||
|
|
||||||
leader_arms: dict[str, str] = field(
|
leader_arms: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"main": "/dev/ttyUSB0",
|
"main": "/dev/ttyUSB0",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Follower arms configuration: left and right ports
|
# Follower arms configuration: left and right ports
|
||||||
follower_arms: dict[str, str] = field(
|
follower_arms: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"main": "/dev/ttyUSB1",
|
"main": "/dev/ttyUSB1",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -519,6 +520,7 @@ class RoarmRobotConfig(RobotConfig):
|
||||||
|
|
||||||
mock: bool = False
|
mock: bool = False
|
||||||
|
|
||||||
|
|
||||||
@RobotConfig.register_subclass("stretch")
|
@RobotConfig.register_subclass("stretch")
|
||||||
@dataclass
|
@dataclass
|
||||||
class StretchRobotConfig(RobotConfig):
|
class StretchRobotConfig(RobotConfig):
|
||||||
|
@ -554,6 +556,7 @@ class StretchRobotConfig(RobotConfig):
|
||||||
|
|
||||||
mock: bool = False
|
mock: bool = False
|
||||||
|
|
||||||
|
|
||||||
@RobotConfig.register_subclass("lekiwi")
|
@RobotConfig.register_subclass("lekiwi")
|
||||||
@dataclass
|
@dataclass
|
||||||
class LeKiwiRobotConfig(RobotConfig):
|
class LeKiwiRobotConfig(RobotConfig):
|
||||||
|
|
|
@ -4,19 +4,17 @@ and send orders to its motors.
|
||||||
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
|
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
|
||||||
# calibration procedure, to make it easy for people to add their own robot.
|
# calibration procedure, to make it easy for people to add their own robot.
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from roarm_sdk.roarm import roarm
|
||||||
|
|
||||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||||
from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
|
from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
|
||||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
from roarm_sdk.roarm import roarm
|
|
||||||
|
|
||||||
def ensure_safe_goal_position(
|
def ensure_safe_goal_position(
|
||||||
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
|
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
|
||||||
|
@ -37,7 +35,8 @@ def ensure_safe_goal_position(
|
||||||
|
|
||||||
return safe_goal_pos
|
return safe_goal_pos
|
||||||
|
|
||||||
def make_roarm_from_configs(configs: dict[str, str]) -> (dict[str, roarm]):
|
|
||||||
|
def make_roarm_from_configs(configs: dict[str, str]) -> dict[str, roarm]:
|
||||||
roarms = {}
|
roarms = {}
|
||||||
|
|
||||||
for key, port in configs.items():
|
for key, port in configs.items():
|
||||||
|
@ -45,8 +44,8 @@ def make_roarm_from_configs(configs: dict[str, str]) -> (dict[str, roarm]):
|
||||||
|
|
||||||
return roarms
|
return roarms
|
||||||
|
|
||||||
class RoarmRobot:
|
|
||||||
|
|
||||||
|
class RoarmRobot:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: RoarmRobotConfig,
|
config: RoarmRobotConfig,
|
||||||
|
@ -58,7 +57,7 @@ class RoarmRobot:
|
||||||
self.cameras = make_cameras_from_configs(self.config.cameras)
|
self.cameras = make_cameras_from_configs(self.config.cameras)
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
self.logs = {}
|
self.logs = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_features(self) -> dict:
|
def camera_features(self) -> dict:
|
||||||
cam_ft = {}
|
cam_ft = {}
|
||||||
|
|
|
@ -7,9 +7,9 @@ from lerobot.common.robot_devices.robots.configs import (
|
||||||
LeKiwiRobotConfig,
|
LeKiwiRobotConfig,
|
||||||
ManipulatorRobotConfig,
|
ManipulatorRobotConfig,
|
||||||
MossRobotConfig,
|
MossRobotConfig,
|
||||||
|
RoarmRobotConfig,
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
So100RobotConfig,
|
So100RobotConfig,
|
||||||
RoarmRobotConfig,
|
|
||||||
StretchRobotConfig,
|
StretchRobotConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,6 +73,7 @@ def make_robot_from_config(config: RobotConfig):
|
||||||
|
|
||||||
return StretchRobot(config)
|
return StretchRobot(config)
|
||||||
|
|
||||||
|
|
||||||
def make_robot(robot_type: str, **kwargs) -> Robot:
|
def make_robot(robot_type: str, **kwargs) -> Robot:
|
||||||
config = make_robot_config(robot_type, **kwargs)
|
config = make_robot_config(robot_type, **kwargs)
|
||||||
return make_robot_from_config(config)
|
return make_robot_from_config(config)
|
||||||
|
|
Loading…
Reference in New Issue