From c6212d585de741539040d064e533d6ead509ae8a Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 24 Mar 2025 20:56:58 +0100 Subject: [PATCH] Add raw_values option --- lerobot/common/motors/motors_bus.py | 30 +++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py index a63393ec..19af9a5d 100644 --- a/lerobot/common/motors/motors_bus.py +++ b/lerobot/common/motors/motors_bus.py @@ -471,13 +471,19 @@ class MotorsBus(abc.ABC): pass @overload - def sync_read(self, data_name: str, motors: None = ..., num_retry: int = ...) -> dict[str, Value]: ... + def sync_read( + self, data_name: str, motors: None = ..., raw_values: bool = ..., num_retry: int = ... + ) -> dict[str, Value]: ... @overload def sync_read( - self, data_name: str, motors: NameOrID | list[NameOrID], num_retry: int = ... + self, data_name: str, motors: NameOrID | list[NameOrID], raw_values: bool = ..., num_retry: int = ... ) -> dict[NameOrID, Value]: ... def sync_read( - self, data_name: str, motors: NameOrID | list[NameOrID] | None = None, num_retry: int = 0 + self, + data_name: str, + motors: NameOrID | list[NameOrID] | None = None, + raw_values: bool = False, + num_retry: int = 0, ) -> dict[NameOrID, Value]: if not self.is_connected: raise DeviceNotConnectedError( @@ -503,7 +509,7 @@ class MotorsBus(abc.ABC): f"{self.packet_handler.getTxRxResult(comm)}" ) - if data_name in self.calibration_required and self.calibration is not None: + if not raw_values and data_name in self.calibration_required and self.calibration is not None: ids_values = self._calibrate_values(ids_values) return {id_key_map[idx]: val for idx, val in ids_values.items()} @@ -551,7 +557,13 @@ class MotorsBus(abc.ABC): # for idx in motor_ids: # value = self.reader.getData(idx, address, n_bytes) - def sync_write(self, data_name: str, values: Value | dict[NameOrID, Value], num_retry: int = 0) -> None: + def sync_write( + self, + data_name: str, + values: Value | dict[NameOrID, Value], + raw_values: bool = False, + num_retry: int = 0, + ) -> None: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." @@ -564,7 +576,7 @@ class MotorsBus(abc.ABC): else: raise ValueError(f"'values' is expected to be a single value or a dict. Got {values}") - if data_name in self.calibration_required and self.calibration is not None: + if not raw_values and data_name in self.calibration_required and self.calibration is not None: ids_values = self._uncalibrate_values(ids_values) comm = self._sync_write(data_name, ids_values, num_retry) @@ -602,7 +614,9 @@ class MotorsBus(abc.ABC): data = self._split_int_to_bytes(value, n_bytes) self.sync_writer.addParam(idx, data) - def write(self, data_name: str, motor: NameOrID, value: Value, num_retry: int = 0) -> None: + def write( + self, data_name: str, motor: NameOrID, value: Value, raw_value: bool = False, num_retry: int = 0 + ) -> None: if not self.is_connected: raise DeviceNotConnectedError( f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." @@ -610,7 +624,7 @@ class MotorsBus(abc.ABC): idx = self._get_motor_id(motor) - if data_name in self.calibration_required and self.calibration is not None: + if not raw_value and data_name in self.calibration_required and self.calibration is not None: id_value = self._uncalibrate_values({idx: value}) value = id_value[idx]