Add logger, rm logs

This commit is contained in:
Simon Alibert 2025-03-22 10:33:42 +01:00
parent 9e34c1d731
commit 40675ec76c
1 changed files with 9 additions and 39 deletions

View File

@ -21,7 +21,7 @@
import abc
import json
import time
import logging
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
@ -33,23 +33,13 @@ import serial
from deepdiff import DeepDiff
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
NameOrID: TypeAlias = str | int
Value: TypeAlias = int | float
MAX_ID_RANGE = 252
def get_group_sync_key(data_name: str, motor_ids: list[int]) -> str:
group_key = f"{data_name}_" + "_".join([str(idx) for idx in motor_ids])
return group_key
def get_log_name(var_name: str, fn_name: str, data_name: str, motor_names: list[str]) -> str:
group_key = get_group_sync_key(data_name, motor_names)
log_name = f"{var_name}_{fn_name}_{group_key}"
return log_name
logger = logging.getLogger(__name__)
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
@ -268,7 +258,6 @@ class MotorsBus(abc.ABC):
self.reader: GroupSyncRead
self.writer: GroupSyncWrite
self.logs = {} # TODO(aliberts): use subclass logger
self.calibration = None
self._id_to_model = {m.id: m.model for m in self.motors.values()}
@ -363,10 +352,11 @@ class MotorsBus(abc.ABC):
def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None:
idx = self.get_motor_id(motor)
for _ in range(1 + num_retry):
for n_try in range(1 + num_retry):
model_number, comm, error = self.packet_handler.ping(self.port_handler, idx)
if self._is_comm_success(comm):
return model_number
logger.debug(f"ping failed for {idx=}: {n_try=} got {comm=} {error=}")
if raise_on_error:
raise ConnectionError(f"Ping motor {motor} returned a {error} error code.")
@ -463,8 +453,6 @@ class MotorsBus(abc.ABC):
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
start_time = time.perf_counter()
id_key_map: dict[int, NameOrID] = {}
if motors is None:
id_key_map = {m.id: name for name, m in self.motors.items()}
@ -493,17 +481,7 @@ class MotorsBus(abc.ABC):
if data_name in self.calibration_required and self.calibration is not None:
ids_values = self.calibrate_values(ids_values)
keys_values = {id_key_map[idx]: val for idx, val in ids_values.items()}
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_ids)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_ids)
self.logs[ts_utc_name] = capture_timestamp_utc()
return keys_values
return {id_key_map[idx]: val for idx, val in ids_values.items()}
def _read(
self, motor_ids: list[str], address: int, n_bytes: int, num_retry: int = 0
@ -520,10 +498,11 @@ class MotorsBus(abc.ABC):
for idx in motor_ids:
self.reader.addParam(idx)
for _ in range(1 + num_retry):
for n_try in range(1 + num_retry):
comm = self.reader.txRxPacket()
if self._is_comm_success(comm):
break
logger.debug(f"ids={list(motor_ids)} @{address} ({n_bytes} bytes) {n_try=} got {comm=}")
values = {idx: self.reader.getData(idx, address, n_bytes) for idx in motor_ids}
return comm, values
@ -543,8 +522,6 @@ class MotorsBus(abc.ABC):
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
start_time = time.perf_counter()
if isinstance(values, int):
ids_values = {id_: values for id_ in self.ids}
elif isinstance(values, dict):
@ -569,14 +546,6 @@ class MotorsBus(abc.ABC):
f"{self.packet_handler.getTxRxResult(comm)}"
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, list(ids_values))
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time when the write has been completed
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, list(ids_values))
self.logs[ts_utc_name] = capture_timestamp_utc()
def _write(self, ids_values: dict[int, int], address: int, n_bytes: int, num_retry: int = 0) -> int:
self.writer.clearParam()
self.writer.start_address = address
@ -586,10 +555,11 @@ class MotorsBus(abc.ABC):
data = self.split_int_bytes(value, n_bytes)
self.writer.addParam(idx, data)
for _ in range(1 + num_retry):
for n_try in range(1 + num_retry):
comm = self.writer.txPacket()
if self._is_comm_success(comm):
break
logger.debug(f"ids={list(ids_values)} @{address} ({n_bytes} bytes) {n_try=} got {comm=}")
return comm