Improve type hints
This commit is contained in:
parent
858678786a
commit
bd5b181dfd
|
@ -518,7 +518,7 @@ class DynamixelMotorsBus(MotorsBus):
|
||||||
else:
|
else:
|
||||||
return values[0]
|
return values[0]
|
||||||
|
|
||||||
def _read(self, data_name, motor_names: str | list[str] | None = None):
|
def _read(self, data_name: str, motor_names: list[str]):
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
motor_ids = []
|
motor_ids = []
|
||||||
|
|
|
@ -372,7 +372,7 @@ class FeetechMotorsBus(MotorsBus):
|
||||||
else:
|
else:
|
||||||
return values[0]
|
return values[0]
|
||||||
|
|
||||||
def _read(self, data_name, motor_names: str | list[str] | None = None):
|
def _read(self, data_name: str, motor_names: list[str]):
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
motor_ids = []
|
motor_ids = []
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
# ruff: noqa: N802
|
# ruff: noqa: N802
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import enum
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from enum import Enum
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -34,18 +34,18 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
||||||
MAX_ID_RANGE = 252
|
MAX_ID_RANGE = 252
|
||||||
|
|
||||||
|
|
||||||
def get_group_sync_key(data_name: str, motor_names: list[str]):
|
def get_group_sync_key(data_name: str, motor_names: list[str]) -> str:
|
||||||
group_key = f"{data_name}_" + "_".join(motor_names)
|
group_key = f"{data_name}_" + "_".join(motor_names)
|
||||||
return group_key
|
return group_key
|
||||||
|
|
||||||
|
|
||||||
def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]):
|
def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]) -> str:
|
||||||
group_key = get_group_sync_key(data_name, motor_names)
|
group_key = get_group_sync_key(data_name, motor_names)
|
||||||
log_name = f"{var_name}_{fn_name}_{group_key}"
|
log_name = f"{var_name}_{fn_name}_{group_key}"
|
||||||
return log_name
|
return log_name
|
||||||
|
|
||||||
|
|
||||||
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str):
|
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
|
||||||
all_addr = []
|
all_addr = []
|
||||||
all_bytes = []
|
all_bytes = []
|
||||||
for model in motor_models:
|
for model in motor_models:
|
||||||
|
@ -55,26 +55,28 @@ def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[st
|
||||||
|
|
||||||
if len(set(all_addr)) != 1:
|
if len(set(all_addr)) != 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
f"At least two motor models use a different address for `data_name`='{data_name}'"
|
||||||
|
f"({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(set(all_bytes)) != 1:
|
if len(set(all_bytes)) != 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
f"At least two motor models use a different bytes representation for `data_name`='{data_name}'"
|
||||||
|
f"({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TorqueMode(enum.Enum):
|
class TorqueMode(Enum):
|
||||||
ENABLED = 1
|
ENABLED = 1
|
||||||
DISABLED = 0
|
DISABLED = 0
|
||||||
|
|
||||||
|
|
||||||
class DriveMode(enum.Enum):
|
class DriveMode(Enum):
|
||||||
NON_INVERTED = 0
|
NON_INVERTED = 0
|
||||||
INVERTED = 1
|
INVERTED = 1
|
||||||
|
|
||||||
|
|
||||||
class CalibrationMode(enum.Enum):
|
class CalibrationMode(Enum):
|
||||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||||
DEGREE = 0
|
DEGREE = 0
|
||||||
# Joints with liner motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
# Joints with liner motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||||
|
@ -246,7 +248,8 @@ class MotorsBus(abc.ABC):
|
||||||
if idx != present_idx:
|
if idx != present_idx:
|
||||||
# sanity check
|
# sanity check
|
||||||
raise OSError(
|
raise OSError(
|
||||||
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
|
"Motor index used to communicate through the bus is not the same as the one present in the motor "
|
||||||
|
"memory. The motor memory might be damaged."
|
||||||
)
|
)
|
||||||
indices.append(idx)
|
indices.append(idx)
|
||||||
|
|
||||||
|
@ -298,12 +301,12 @@ class MotorsBus(abc.ABC):
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _read(self, data_name, motor_names: str | list[str] | None = None):
|
def _read(self, data_name: str, motor_names: list[str]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def write(
|
def write(
|
||||||
self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None
|
self, data_name: str, values: int | float | np.ndarray, motor_names: str | list[str] | None = None
|
||||||
):
|
) -> None:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`."
|
f"{self.__name__}({self.port}) is not connected. You need to run `{self.__name__}.connect()`."
|
||||||
|
@ -335,7 +338,7 @@ class MotorsBus(abc.ABC):
|
||||||
def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None:
|
def _write(self, data_name: str, values: list[int], motor_names: list[str]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self) -> None:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
raise DeviceNotConnectedError(
|
raise DeviceNotConnectedError(
|
||||||
f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first."
|
f"{self.__name__}({self.port}) is not connected. Try running `{self.__name__}.connect()` first."
|
||||||
|
|
Loading…
Reference in New Issue