From bd5b181dfde6b26f34f9bafc37f988fd06e9c4ec Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sat, 15 Mar 2025 21:33:45 +0100 Subject: [PATCH] Improve type hints --- lerobot/common/motors/dynamixel/dynamixel.py | 2 +- lerobot/common/motors/feetech/feetech.py | 2 +- lerobot/common/motors/motors_bus.py | 29 +++++++++++--------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index b0f9556f..8f780de0 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -518,7 +518,7 @@ class DynamixelMotorsBus(MotorsBus): else: return values[0] - def _read(self, data_name, motor_names: str | list[str] | None = None): + def _read(self, data_name: str, motor_names: list[str]): import dynamixel_sdk as dxl motor_ids = [] diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index a0bcc083..3b4d4e84 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -372,7 +372,7 @@ class FeetechMotorsBus(MotorsBus): else: return values[0] - def _read(self, data_name, motor_names: str | list[str] | None = None): + def _read(self, data_name: str, motor_names: list[str]): import scservo_sdk as scs motor_ids = [] diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 65fe54da..e2fdcb73 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -20,9 +20,9 @@ # ruff: noqa: N802 import abc -import enum import time import traceback +from enum import Enum from typing import Protocol import numpy as np @@ -34,18 +34,18 @@ from lerobot.common.utils.utils import capture_timestamp_utc MAX_ID_RANGE = 252 -def get_group_sync_key(data_name: str, motor_names: list[str]): +def get_group_sync_key(data_name: str, motor_names: list[str]) -> str: group_key = f"{data_name}_" + "_".join(motor_names) return group_key -def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]): +def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]) -> str: group_key = get_group_sync_key(data_name, motor_names) log_name = f"{var_name}_{fn_name}_{group_key}" return log_name -def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str): +def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None: all_addr = [] all_bytes = [] for model in motor_models: @@ -55,26 +55,28 @@ def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[st if len(set(all_addr)) != 1: raise NotImplementedError( - f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer." + f"At least two motor models use a different address for `data_name`='{data_name}'" + f"({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer." ) if len(set(all_bytes)) != 1: raise NotImplementedError( - f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer." + f"At least two motor models use a different bytes representation for `data_name`='{data_name}'" + f"({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer." ) -class TorqueMode(enum.Enum): +class TorqueMode(Enum): ENABLED = 1 DISABLED = 0 -class DriveMode(enum.Enum): +class DriveMode(Enum): NON_INVERTED = 0 INVERTED = 1 -class CalibrationMode(enum.Enum): +class CalibrationMode(Enum): # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] DEGREE = 0 # Joints with liner motions (like gripper of Aloha) are expressed in nominal range of [0, 100] @@ -246,7 +248,8 @@ class MotorsBus(abc.ABC): if idx != present_idx: # sanity check raise OSError( - "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged." + "Motor index used to communicate through the bus is not the same as the one present in the motor " + "memory. The motor memory might be damaged." ) indices.append(idx) @@ -298,12 +301,12 @@ class MotorsBus(abc.ABC): return values @abc.abstractmethod - def _read(self, data_name, motor_names: str | list[str] | None = None): + def _read(self, data_name: str, motor_names: list[str]): pass def write( self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None - ): + ) -> None: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`." @@ -335,7 +338,7 @@ class MotorsBus(abc.ABC): def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None: pass - def disconnect(self): + def disconnect(self) -> None: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first."