From cf963eb1b0fbc929b765ee75b1aba2df19188d31 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 25 Mar 2025 11:12:26 +0100 Subject: [PATCH] Ensure motors exist at connection time --- lerobot/common/motors/dynamixel/dynamixel.py | 4 +- lerobot/common/motors/feetech/feetech.py | 14 ++- lerobot/common/motors/motors_bus.py | 95 ++++++++++++++------ 3 files changed, 80 insertions(+), 33 deletions(-) diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py index affe774b..11094f60 100644 --- a/lerobot/common/motors/dynamixel/dynamixel.py +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -138,7 +138,7 @@ class DynamixelMotorsBus(MotorsBus): ] return data - def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, str] | None: + def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: for n_try in range(1 + num_retry): data_list, comm = self.packet_handler.broadcastPing(self.port_handler) if self._is_comm_success(comm): @@ -152,4 +152,4 @@ class DynamixelMotorsBus(MotorsBus): return data_list if data_list else None - return {id_: self._model_nb_to_model(data[0]) for id_, data in data_list.items()} + return {id_: data[0] for id_, data in data_list.items()} diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py index 67f56efd..188350af 100644 --- a/lerobot/common/motors/feetech/feetech.py +++ b/lerobot/common/motors/feetech/feetech.py @@ -193,7 +193,7 @@ class FeetechMotorsBus(MotorsBus): del rxpacket[0:id_] rx_length = rx_length - id_ - def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, str] | None: + def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: for n_try in range(1 + num_retry): ids_status, comm = self._broadcast_ping() if self._is_comm_success(comm): @@ -211,5 +211,13 @@ class FeetechMotorsBus(MotorsBus): if ids_errors: display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()} logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}") - model_numbers = self.sync_read("Model_Number", list(ids_status), num_retry) - return {id_: self._model_nb_to_model(model_nb) for id_, model_nb in model_numbers.items()} + comm, model_numbers = self._sync_read( + "Model_Number", list(ids_status), model="scs_series", num_retry=num_retry + ) + if not self._is_comm_success(comm): + if raise_on_error: + raise ConnectionError(self.packet_handler.getRxPacketError(comm)) + + return model_numbers if model_numbers else None + + return model_numbers diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index 224726e7..37b1fd77 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -42,11 +42,27 @@ MAX_ID_RANGE = 252 logger = logging.getLogger(__name__) +def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]: + try: + return model_ctrl_table[model] + except KeyError: + raise KeyError(f"Control table for {model=} not found.") from None + + +def get_address(model_ctrl_table: dict[str, dict], model: str, data_name: str) -> tuple[int, int]: + ctrl_table = get_ctrl_table(model_ctrl_table, model) + try: + addr, bytes = ctrl_table[data_name] + return addr, bytes + except KeyError: + raise KeyError(f"Address for '{data_name}' not found in {model} control table.") from None + + def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None: all_addr = [] all_bytes = [] for model in motor_models: - addr, bytes = model_ctrl_table[model][data_name] + addr, bytes = get_address(model_ctrl_table, model, data_name) all_addr.append(addr) all_bytes.append(bytes) @@ -275,7 +291,7 @@ class MotorsBus(abc.ABC): return ( f"{self.__class__.__name__}(\n" f" Port: '{self.port}',\n" - f" Motors: \n{pformat(self.motors, indent=8)},\n" + f" Motors: \n{pformat(self.motors, indent=8, sort_dicts=False)},\n" ")',\n" ) @@ -285,7 +301,9 @@ class MotorsBus(abc.ABC): return False first_table = self.model_ctrl_table[self.models[0]] - return any(DeepDiff(first_table, self.model_ctrl_table[model]) for model in self.models[1:]) + return any( + DeepDiff(first_table, get_ctrl_table(self.model_ctrl_table, model)) for model in self.models[1:] + ) @cached_property def names(self) -> list[str]: @@ -317,15 +335,12 @@ class MotorsBus(abc.ABC): raise TypeError(f"'{motor}' should be int, str.") def _validate_motors(self) -> None: - # TODO(aliberts): improve error messages for this (display problematics values) if len(self.ids) != len(set(self.ids)): - raise ValueError("Some motors have the same id.") + raise ValueError(f"Some motors have the same id!\n{self}") - if len(self.names) != len(set(self.names)): - raise ValueError("Some motors have the same name.") - - if any(model not in self.model_resolution_table for model in self.models): - raise ValueError("Some motors models are not available.") + # Ensure ctrl table available for all models + for model in self.models: + get_ctrl_table(self.model_ctrl_table, model) def _is_comm_success(self, comm: int) -> bool: return comm == self._comm_success @@ -333,11 +348,30 @@ class MotorsBus(abc.ABC): def _is_error(self, error: int) -> bool: return error != self._no_error + def _assert_motors_exist(self) -> None: + found_models = self.broadcast_ping() + expected_models = {m.id: self.model_number_table[m.model] for m in self.motors.values()} + if not set(found_models) == set(self.ids): + raise RuntimeError( + f"{self.__class__.__name__} is supposed to have these motors: ({{id: model_nb}})" + f"\n{pformat(expected_models, indent=4, sort_dicts=False)}\n" + f"But it found these motors on port '{self.port}':" + f"\n{pformat(found_models, indent=4, sort_dicts=False)}\n" + ) + + for id_, model in expected_models.items(): + if found_models[id_] != model: + raise RuntimeError( + f"Motor '{self._id_to_name(id_)}' (id={id_}) is supposed to be of model_number={model} " + f"('{self._id_to_model(id_)}') but a model_number={found_models[id_]} " + "was found instead for that id." + ) + @property def is_connected(self) -> bool: return self.port_handler.is_open - def connect(self) -> None: + def connect(self, assert_motors_exist: bool = True) -> None: if self.is_connected: raise DeviceAlreadyConnectedError( f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice." @@ -346,6 +380,8 @@ class MotorsBus(abc.ABC): try: if not self.port_handler.openPort(): raise OSError(f"Failed to open port '{self.port}'.") + elif assert_motors_exist: + self._assert_motors_exist() except (FileNotFoundError, OSError, serial.SerialException) as e: logger.error( f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port." @@ -441,7 +477,7 @@ class MotorsBus(abc.ABC): """ pass - def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> str | None: + def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None: id_ = self._get_motor_id(motor) for n_try in range(1 + num_retry): model_number, comm, error = self.packet_handler.ping(self.port_handler, id_) @@ -460,7 +496,7 @@ class MotorsBus(abc.ABC): else: return - return self._model_nb_to_model(model_number) + return model_number @abc.abstractmethod def broadcast_ping( @@ -470,16 +506,22 @@ class MotorsBus(abc.ABC): @overload def sync_read( - self, data_name: str, motors: None = ..., raw_values: bool = ..., num_retry: int = ... + self, data_name: str, motors: None = ..., *, raw_values: bool = ..., num_retry: int = ... ) -> dict[str, Value]: ... @overload def sync_read( - self, data_name: str, motors: NameOrID | list[NameOrID], raw_values: bool = ..., num_retry: int = ... + self, + data_name: str, + motors: NameOrID | list[NameOrID], + *, + raw_values: bool = ..., + num_retry: int = ..., ) -> dict[NameOrID, Value]: ... def sync_read( self, data_name: str, motors: NameOrID | list[NameOrID] | None = None, + *, raw_values: bool = False, num_retry: int = 0, ) -> dict[NameOrID, Value]: @@ -500,7 +542,7 @@ class MotorsBus(abc.ABC): motor_ids = list(id_key_map) - comm, ids_values = self._sync_read(data_name, motor_ids, num_retry) + comm, ids_values = self._sync_read(data_name, motor_ids, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to sync read '{data_name}' on {motor_ids=} after {num_retry + 1} tries." @@ -513,14 +555,14 @@ class MotorsBus(abc.ABC): return {id_key_map[id_]: val for id_, val in ids_values.items()} def _sync_read( - self, data_name: str, motor_ids: list[str], num_retry: int = 0 + self, data_name: str, motor_ids: list[str], model: str | None = None, num_retry: int = 0 ) -> tuple[int, dict[int, int]]: if self._has_different_ctrl_tables: models = [self._id_to_model(id_) for id_ in motor_ids] assert_same_address(self.model_ctrl_table, models, data_name) - model = self._id_to_model(next(iter(motor_ids))) - addr, n_bytes = self.model_ctrl_table[model][data_name] + model = self._id_to_model(next(iter(motor_ids))) if model is None else model + addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) self._setup_sync_reader(motor_ids, addr, n_bytes) # FIXME(aliberts, pkooij): We should probably not have to do this. @@ -559,6 +601,7 @@ class MotorsBus(abc.ABC): self, data_name: str, values: Value | dict[NameOrID, Value], + *, raw_values: bool = False, num_retry: int = 0, ) -> None: @@ -577,7 +620,7 @@ class MotorsBus(abc.ABC): if not raw_values and data_name in self.calibration_required and self.calibration is not None: ids_values = self._uncalibrate_values(ids_values) - comm = self._sync_write(data_name, ids_values, num_retry) + comm = self._sync_write(data_name, ids_values, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." @@ -590,7 +633,7 @@ class MotorsBus(abc.ABC): assert_same_address(self.model_ctrl_table, models, data_name) model = self._id_to_model(next(iter(ids_values))) - addr, n_bytes = self.model_ctrl_table[model][data_name] + addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) self._setup_sync_writer(ids_values, addr, n_bytes) for n_try in range(1 + num_retry): @@ -613,7 +656,7 @@ class MotorsBus(abc.ABC): self.sync_writer.addParam(id_, data) def write( - self, data_name: str, motor: NameOrID, value: Value, raw_value: bool = False, num_retry: int = 0 + self, data_name: str, motor: NameOrID, value: Value, *, raw_value: bool = False, num_retry: int = 0 ) -> None: if not self.is_connected: raise DeviceNotConnectedError( @@ -626,7 +669,7 @@ class MotorsBus(abc.ABC): id_value = self._uncalibrate_values({id_: value}) value = id_value[id_] - comm, error = self._write(data_name, id_, value, num_retry) + comm, error = self._write(data_name, id_, value, num_retry=num_retry) if not self._is_comm_success(comm): raise ConnectionError( f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." @@ -640,7 +683,7 @@ class MotorsBus(abc.ABC): def _write(self, data_name: str, motor_id: int, value: int, num_retry: int = 0) -> tuple[int, int]: model = self._id_to_model(motor_id) - addr, n_bytes = self.model_ctrl_table[model][data_name] + addr, n_bytes = get_address(self.model_ctrl_table, model, data_name) data = self._split_int_to_bytes(value, n_bytes) for n_try in range(1 + num_retry): @@ -662,7 +705,3 @@ class MotorsBus(abc.ABC): self.port_handler.closePort() logger.debug(f"{self.__class__.__name__} disconnected.") - - def __del__(self): - if self.is_connected: - self.disconnect()