Improve type hints
This commit is contained in:
parent
858678786a
commit
bd5b181dfd
|
@ -518,7 +518,7 @@ class DynamixelMotorsBus(MotorsBus):
|
|||
else:
|
||||
return values[0]
|
||||
|
||||
def _read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
def _read(self, data_name: str, motor_names: list[str]):
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
motor_ids = []
|
||||
|
|
|
@ -372,7 +372,7 @@ class FeetechMotorsBus(MotorsBus):
|
|||
else:
|
||||
return values[0]
|
||||
|
||||
def _read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
def _read(self, data_name: str, motor_names: list[str]):
|
||||
import scservo_sdk as scs
|
||||
|
||||
motor_ids = []
|
||||
|
|
|
@ -20,9 +20,9 @@
|
|||
# ruff: noqa: N802
|
||||
|
||||
import abc
|
||||
import enum
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
|
@ -34,18 +34,18 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
|||
MAX_ID_RANGE = 252
|
||||
|
||||
|
||||
def get_group_sync_key(data_name: str, motor_names: list[str]):
|
||||
def get_group_sync_key(data_name: str, motor_names: list[str]) -> str:
|
||||
group_key = f"{data_name}_" + "_".join(motor_names)
|
||||
return group_key
|
||||
|
||||
|
||||
def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]):
|
||||
def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]) -> str:
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
log_name = f"{var_name}_{fn_name}_{group_key}"
|
||||
return log_name
|
||||
|
||||
|
||||
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str):
|
||||
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
|
||||
all_addr = []
|
||||
all_bytes = []
|
||||
for model in motor_models:
|
||||
|
@ -55,26 +55,28 @@ def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[st
|
|||
|
||||
if len(set(all_addr)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
||||
f"At least two motor models use a different address for `data_name`='{data_name}'"
|
||||
f"({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
if len(set(all_bytes)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
||||
f"At least two motor models use a different bytes representation for `data_name`='{data_name}'"
|
||||
f"({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
|
||||
class TorqueMode(enum.Enum):
|
||||
class TorqueMode(Enum):
|
||||
ENABLED = 1
|
||||
DISABLED = 0
|
||||
|
||||
|
||||
class DriveMode(enum.Enum):
|
||||
class DriveMode(Enum):
|
||||
NON_INVERTED = 0
|
||||
INVERTED = 1
|
||||
|
||||
|
||||
class CalibrationMode(enum.Enum):
|
||||
class CalibrationMode(Enum):
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
DEGREE = 0
|
||||
# Joints with liner motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
|
@ -246,7 +248,8 @@ class MotorsBus(abc.ABC):
|
|||
if idx != present_idx:
|
||||
# sanity check
|
||||
raise OSError(
|
||||
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
|
||||
"Motor index used to communicate through the bus is not the same as the one present in the motor "
|
||||
"memory. The motor memory might be damaged."
|
||||
)
|
||||
indices.append(idx)
|
||||
|
||||
|
@ -298,12 +301,12 @@ class MotorsBus(abc.ABC):
|
|||
return values
|
||||
|
||||
@abc.abstractmethod
|
||||
def _read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
def _read(self, data_name: str, motor_names: list[str]):
|
||||
pass
|
||||
|
||||
def write(
|
||||
self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None
|
||||
):
|
||||
) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`."
|
||||
|
@ -335,7 +338,7 @@ class MotorsBus(abc.ABC):
|
|||
def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None:
|
||||
pass
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first."
|
||||
|
|
Loading…
Reference in New Issue