Improve type hints

This commit is contained in:
Simon Alibert 2025-03-15 21:33:45 +01:00
parent 858678786a
commit bd5b181dfd
3 changed files with 18 additions and 15 deletions

View File

@ -518,7 +518,7 @@ class DynamixelMotorsBus(MotorsBus):
else:
return values[0]
def _read(self, data_name, motor_names: str | list[str] | None = None):
def _read(self, data_name: str, motor_names: list[str]):
import dynamixel_sdk as dxl
motor_ids = []

View File

@ -372,7 +372,7 @@ class FeetechMotorsBus(MotorsBus):
else:
return values[0]
def _read(self, data_name, motor_names: str | list[str] | None = None):
def _read(self, data_name: str, motor_names: list[str]):
import scservo_sdk as scs
motor_ids = []

View File

@ -20,9 +20,9 @@
# ruff: noqa: N802
import abc
import enum
import time
import traceback
from enum import Enum
from typing import Protocol
import numpy as np
@ -34,18 +34,18 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_ID_RANGE = 252
def get_group_sync_key(data_name: str, motor_names: list[str]):
def get_group_sync_key(data_name: str, motor_names: list[str]) -> str:
group_key = f"{data_name}_" + "_".join(motor_names)
return group_key
def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]):
def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]) -> str:
group_key = get_group_sync_key(data_name, motor_names)
log_name = f"{var_name}_{fn_name}_{group_key}"
return log_name
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str):
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
all_addr = []
all_bytes = []
for model in motor_models:
@ -55,26 +55,28 @@ def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[st
if len(set(all_addr)) != 1:
raise NotImplementedError(
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
f"At least two motor models use a different address for `data_name`='{data_name}'"
f"({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
)
if len(set(all_bytes)) != 1:
raise NotImplementedError(
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
f"At least two motor models use a different bytes representation for `data_name`='{data_name}'"
f"({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
)
class TorqueMode(enum.Enum):
class TorqueMode(Enum):
ENABLED = 1
DISABLED = 0
class DriveMode(enum.Enum):
class DriveMode(Enum):
NON_INVERTED = 0
INVERTED = 1
class CalibrationMode(enum.Enum):
class CalibrationMode(Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0
# Joints with liner motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
@ -246,7 +248,8 @@ class MotorsBus(abc.ABC):
if idx != present_idx:
# sanity check
raise OSError(
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
"Motor index used to communicate through the bus is not the same as the one present in the motor "
"memory. The motor memory might be damaged."
)
indices.append(idx)
@ -298,12 +301,12 @@ class MotorsBus(abc.ABC):
return values
@abc.abstractmethod
def _read(self, data_name, motor_names: str | list[str] | None = None):
def _read(self, data_name: str, motor_names: list[str]):
pass
def write(
self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None
):
) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`."
@ -335,7 +338,7 @@ class MotorsBus(abc.ABC):
def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None:
pass
def disconnect(self):
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first."