[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-03-05 12:12:42 +00:00
parent 472853a818
commit fd8032a023
5 changed files with 32 additions and 21 deletions

View File

@ -31,20 +31,20 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60 MAX_OPENCV_INDEX = 60
undistort = True undistort = True
def undistort_image(image): def undistort_image(image):
import cv2 import cv2
camera_matrix = np.array([
[289.11451, 0., 347.23664], camera_matrix = np.array([[289.11451, 0.0, 347.23664], [0.0, 289.75319, 235.67429], [0.0, 0.0, 1.0]])
[0., 289.75319, 235.67429],
[0., 0., 1.]
])
dist_coeffs = np.array([-0.208848, 0.028006, -0.000705, -0.000820, 0.000000]) dist_coeffs = np.array([-0.208848, 0.028006, -0.000705, -0.000820, 0.000000])
undistorted_image = cv2.undistort(image, camera_matrix, dist_coeffs) undistorted_image = cv2.undistort(image, camera_matrix, dist_coeffs)
return undistorted_image return undistorted_image
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
cameras = [] cameras = []
if platform.system() == "Linux": if platform.system() == "Linux":

View File

@ -153,13 +153,16 @@ def predict_action(observation, policy, device, use_amp):
# return listener, events # return listener, events
def init_keyboard_listener(): def init_keyboard_listener():
events = {} events = {}
events["exit_early"] = False events["exit_early"] = False
events["rerecord_episode"] = False events["rerecord_episode"] = False
events["stop_recording"] = False events["stop_recording"] = False
import threading
from sshkeyboard import listen_keyboard from sshkeyboard import listen_keyboard
import threading
def on_press(key): def on_press(key):
try: try:
if key == "right": if key == "right":
@ -179,7 +182,8 @@ def init_keyboard_listener():
listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press}) listener = threading.Thread(target=listen_keyboard, kwargs={"on_press": on_press})
listener.start() listener.start()
return listener,events return listener, events
def warmup_record( def warmup_record(
robot, robot,
@ -284,7 +288,7 @@ def control_loop(
dataset.add_frame(frame) dataset.add_frame(frame)
if display_cameras: if display_cameras:
# if display_cameras and not is_headless(): # if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key] image_keys = [key for key in observation if "image" in key]
for key in image_keys: for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
@ -316,6 +320,7 @@ def reset_environment(robot, events, reset_time_s, fps):
teleoperate=True, teleoperate=True,
) )
# def stop_recording(robot, listener, display_cameras): # def stop_recording(robot, listener, display_cameras):
# robot.disconnect() # robot.disconnect()
@ -326,16 +331,19 @@ def reset_environment(robot, events, reset_time_s, fps):
# if display_cameras: # if display_cameras:
# cv2.destroyAllWindows() # cv2.destroyAllWindows()
def stop_recording(robot, listener, display_cameras): def stop_recording(robot, listener, display_cameras):
robot.disconnect() robot.disconnect()
from sshkeyboard import stop_listening from sshkeyboard import stop_listening
if listener is not None: if listener is not None:
stop_listening() stop_listening()
if display_cameras: if display_cameras:
cv2.destroyAllWindows() cv2.destroyAllWindows()
def sanity_check_dataset_name(repo_id, policy_cfg): def sanity_check_dataset_name(repo_id, policy_cfg):
_, dataset_name = repo_id.split("/") _, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy # either repo_id doesnt start with "eval_" and there is no policy

View File

@ -479,6 +479,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
mock: bool = False mock: bool = False
@RobotConfig.register_subclass("roarm_m3") @RobotConfig.register_subclass("roarm_m3")
@dataclass @dataclass
class RoarmRobotConfig(RobotConfig): class RoarmRobotConfig(RobotConfig):
@ -489,14 +490,14 @@ class RoarmRobotConfig(RobotConfig):
leader_arms: dict[str, str] = field( leader_arms: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"main": "/dev/ttyUSB0", "main": "/dev/ttyUSB0",
} }
) )
# Follower arms configuration: left and right ports # Follower arms configuration: left and right ports
follower_arms: dict[str, str] = field( follower_arms: dict[str, str] = field(
default_factory=lambda: { default_factory=lambda: {
"main": "/dev/ttyUSB1", "main": "/dev/ttyUSB1",
} }
) )
@ -519,6 +520,7 @@ class RoarmRobotConfig(RobotConfig):
mock: bool = False mock: bool = False
@RobotConfig.register_subclass("stretch") @RobotConfig.register_subclass("stretch")
@dataclass @dataclass
class StretchRobotConfig(RobotConfig): class StretchRobotConfig(RobotConfig):
@ -554,6 +556,7 @@ class StretchRobotConfig(RobotConfig):
mock: bool = False mock: bool = False
@RobotConfig.register_subclass("lekiwi") @RobotConfig.register_subclass("lekiwi")
@dataclass @dataclass
class LeKiwiRobotConfig(RobotConfig): class LeKiwiRobotConfig(RobotConfig):

View File

@ -4,19 +4,17 @@ and send orders to its motors.
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated # TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
# calibration procedure, to make it easy for people to add their own robot. # calibration procedure, to make it easy for people to add their own robot.
import json
import logging import logging
import time import time
import warnings
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from roarm_sdk.roarm import roarm
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig from lerobot.common.robot_devices.robots.configs import RoarmRobotConfig
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from roarm_sdk.roarm import roarm
def ensure_safe_goal_position( def ensure_safe_goal_position(
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float] goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
@ -37,7 +35,8 @@ def ensure_safe_goal_position(
return safe_goal_pos return safe_goal_pos
def make_roarm_from_configs(configs: dict[str, str]) -> (dict[str, roarm]):
def make_roarm_from_configs(configs: dict[str, str]) -> dict[str, roarm]:
roarms = {} roarms = {}
for key, port in configs.items(): for key, port in configs.items():
@ -45,8 +44,8 @@ def make_roarm_from_configs(configs: dict[str, str]) -> (dict[str, roarm]):
return roarms return roarms
class RoarmRobot:
class RoarmRobot:
def __init__( def __init__(
self, self,
config: RoarmRobotConfig, config: RoarmRobotConfig,
@ -58,7 +57,7 @@ class RoarmRobot:
self.cameras = make_cameras_from_configs(self.config.cameras) self.cameras = make_cameras_from_configs(self.config.cameras)
self.is_connected = False self.is_connected = False
self.logs = {} self.logs = {}
@property @property
def camera_features(self) -> dict: def camera_features(self) -> dict:
cam_ft = {} cam_ft = {}

View File

@ -7,9 +7,9 @@ from lerobot.common.robot_devices.robots.configs import (
LeKiwiRobotConfig, LeKiwiRobotConfig,
ManipulatorRobotConfig, ManipulatorRobotConfig,
MossRobotConfig, MossRobotConfig,
RoarmRobotConfig,
RobotConfig, RobotConfig,
So100RobotConfig, So100RobotConfig,
RoarmRobotConfig,
StretchRobotConfig, StretchRobotConfig,
) )
@ -73,6 +73,7 @@ def make_robot_from_config(config: RobotConfig):
return StretchRobot(config) return StretchRobot(config)
def make_robot(robot_type: str, **kwargs) -> Robot: def make_robot(robot_type: str, **kwargs) -> Robot:
config = make_robot_config(robot_type, **kwargs) config = make_robot_config(robot_type, **kwargs)
return make_robot_from_config(config) return make_robot_from_config(config)