[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-03-05 12:12:42 +00:00
parent 472853a818
commit fd8032a023
5 changed files with 32 additions and 21 deletions

View File

@ -31,20 +31,20 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60
undistort = True
def undistort_image(image):
import cv2
camera_matrix = np.array([
[289.11451, 0., 347.23664],
[0., 289.75319, 235.67429],
[0., 0., 1.]
])
camera_matrix = np.array([[289.11451, 0.0, 347.23664], [0.0, 289.75319, 235.67429], [0.0, 0.0, 1.0]])
dist_coeffs = np.array([-0.208848, 0.028006, -0.000705, -0.000820, 0.000000])
undistorted_image = cv2.undistort(image, camera_matrix, dist_coeffs)
return undistorted_image
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
cameras = []
if platform.system() == "Linux":

View File

@ -153,13 +153,16 @@ def predict_action(observation, policy, device, use_amp):
# return listener, events
def init_keyboard_listener():
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
import threading
from sshkeyboard import listen_keyboard
import threading
def on_press(key):
try:
if key == "right":
@ -179,7 +182,8 @@ def init_keyboard_listener():
listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
listener.start()
return listener,events
return listener, events
def warmup_record(
robot,
@ -284,7 +288,7 @@ def control_loop(
dataset.add_frame(frame)
if display_cameras:
# if display_cameras and not is_headless():
# if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@ -316,6 +320,7 @@ def reset_environment(robot, events, reset_time_s, fps):
teleoperate=True,
)
# def stop_recording(robot, listener, display_cameras):
# robot.disconnect()
@ -326,16 +331,19 @@ def reset_environment(robot, events, reset_time_s, fps):
# if display_cameras:
# cv2.destroyAllWindows()
def stop_recording(robot, listener, display_cameras):
robot.disconnect()
from sshkeyboard import stop_listening
if listener is not None:
stop_listening()
if display_cameras:
cv2.destroyAllWindows()
def sanity_check_dataset_name(repo_id, policy_cfg):
_, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy

View File

@ -479,6 +479,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
mock: bool = False
@RobotConfig.register_subclass("roarm_m3")
@dataclass
class RoarmRobotConfig(RobotConfig):
@ -489,14 +490,14 @@ class RoarmRobotConfig(RobotConfig):
leader_arms: dict[str, str] = field(
default_factory=lambda: {
"main": "/dev/ttyUSB0",
"main": "/dev/ttyUSB0",
}
)
# Follower arms configuration: left and right ports
follower_arms: dict[str, str] = field(
default_factory=lambda: {
"main": "/dev/ttyUSB1",
"main": "/dev/ttyUSB1",
}
)
@ -519,6 +520,7 @@ class RoarmRobotConfig(RobotConfig):
mock: bool = False
@RobotConfig.register_subclass("stretch")
@dataclass
class StretchRobotConfig(RobotConfig):
@ -554,6 +556,7 @@ class StretchRobotConfig(RobotConfig):
mock: bool = False
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiRobotConfig(RobotConfig):

View File

@ -4,19 +4,17 @@ and send orders to its motors.
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
# calibration procedure, to make it easy for people to add their own robot.
import json
import logging
import time
import warnings
from pathlib import Path
import numpy as np
import torch
from roarm_sdk.roarm import roarm
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from roarm_sdk.roarm import roarm
def ensure_safe_goal_position(
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
@ -37,7 +35,8 @@ def ensure_safe_goal_position(
return safe_goal_pos
def make_roarm_from_configs(configs: dict[str, str]) -> (dict[str, roarm]):
def make_roarm_from_configs(configs: dict[str, str]) -> dict[str, roarm]:
roarms = {}
for key, port in configs.items():
@ -45,8 +44,8 @@ def make_roarm_from_configs(configs: dict[str, str]) -> (dict[str, roarm]):
return roarms
class RoarmRobot:
class RoarmRobot:
def __init__(
self,
config: RoarmRobotConfig,
@ -58,7 +57,7 @@ class RoarmRobot:
self.cameras = make_cameras_from_configs(self.config.cameras)
self.is_connected = False
self.logs = {}
@property
def camera_features(self) -> dict:
cam_ft = {}

View File

@ -7,9 +7,9 @@ from lerobot.common.robot_devices.robots.configs import (
LeKiwiRobotConfig,
ManipulatorRobotConfig,
MossRobotConfig,
RoarmRobotConfig,
RobotConfig,
So100RobotConfig,
RoarmRobotConfig,
StretchRobotConfig,
)
@ -73,6 +73,7 @@ def make_robot_from_config(config: RobotConfig):
return StretchRobot(config)
def make_robot(robot_type: str, **kwargs) -> Robot:
config = make_robot_config(robot_type, **kwargs)
return make_robot_from_config(config)