Improve type hints

This commit is contained in:
Simon Alibert 2025-03-15 21:33:45 +01:00
parent 858678786a
commit bd5b181dfd
3 changed files with 18 additions and 15 deletions

View File

@ -518,7 +518,7 @@ class DynamixelMotorsBus(MotorsBus):
else: else:
return values[0] return values[0]
def _read(self, data_name, motor_names: str | list[str] | None = None): def _read(self, data_name: str, motor_names: list[str]):
import dynamixel_sdk as dxl import dynamixel_sdk as dxl
motor_ids = [] motor_ids = []

View File

@ -372,7 +372,7 @@ class FeetechMotorsBus(MotorsBus):
else: else:
return values[0] return values[0]
def _read(self, data_name, motor_names: str | list[str] | None = None): def _read(self, data_name: str, motor_names: list[str]):
import scservo_sdk as scs import scservo_sdk as scs
motor_ids = [] motor_ids = []

View File

@ -20,9 +20,9 @@
# ruff: noqa: N802 # ruff: noqa: N802
import abc import abc
import enum
import time import time
import traceback import traceback
from enum import Enum
from typing import Protocol from typing import Protocol
import numpy as np import numpy as np
@ -34,18 +34,18 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_ID_RANGE = 252 MAX_ID_RANGE = 252
def get_group_sync_key(data_name: str, motor_names: list[str]): def get_group_sync_key(data_name: str, motor_names: list[str]) -> str:
group_key = f"{data_name}_" + "_".join(motor_names) group_key = f"{data_name}_" + "_".join(motor_names)
return group_key return group_key
def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]): def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]) -> str:
group_key = get_group_sync_key(data_name, motor_names) group_key = get_group_sync_key(data_name, motor_names)
log_name = f"{var_name}_{fn_name}_{group_key}" log_name = f"{var_name}_{fn_name}_{group_key}"
return log_name return log_name
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str): def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
all_addr = [] all_addr = []
all_bytes = [] all_bytes = []
for model in motor_models: for model in motor_models:
@ -55,26 +55,28 @@ def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[st
if len(set(all_addr)) != 1: if len(set(all_addr)) != 1:
raise NotImplementedError( raise NotImplementedError(
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer." f"At least two motor models use a different address for `data_name`='{data_name}'"
f"({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
) )
if len(set(all_bytes)) != 1: if len(set(all_bytes)) != 1:
raise NotImplementedError( raise NotImplementedError(
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer." f"At least two motor models use a different bytes representation for `data_name`='{data_name}'"
f"({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
) )
class TorqueMode(enum.Enum): class TorqueMode(Enum):
ENABLED = 1 ENABLED = 1
DISABLED = 0 DISABLED = 0
class DriveMode(enum.Enum): class DriveMode(Enum):
NON_INVERTED = 0 NON_INVERTED = 0
INVERTED = 1 INVERTED = 1
class CalibrationMode(enum.Enum): class CalibrationMode(Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0 DEGREE = 0
# Joints with liner motions (like gripper of Aloha) are expressed in nominal range of [0, 100] # Joints with liner motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
@ -246,7 +248,8 @@ class MotorsBus(abc.ABC):
if idx != present_idx: if idx != present_idx:
# sanity check # sanity check
raise OSError( raise OSError(
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged." "Motor index used to communicate through the bus is not the same as the one present in the motor "
"memory. The motor memory might be damaged."
) )
indices.append(idx) indices.append(idx)
@ -298,12 +301,12 @@ class MotorsBus(abc.ABC):
return values return values
@abc.abstractmethod @abc.abstractmethod
def _read(self, data_name, motor_names: str | list[str] | None = None): def _read(self, data_name: str, motor_names: list[str]):
pass pass
def write( def write(
self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None
): ) -> None:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`." f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`."
@ -335,7 +338,7 @@ class MotorsBus(abc.ABC):
def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None: def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None:
pass pass
def disconnect(self): def disconnect(self) -> None:
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError( raise DeviceNotConnectedError(
f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first." f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first."