From c93cbb83114ebab30feb3e4cfb5602fae7f9d73c Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 3 Mar 2025 18:49:40 +0100 Subject: [PATCH] Fix base robot class --- lerobot/common/robots/config.py | 17 ++++++++ lerobot/common/robots/config_abc.py | 59 -------------------------- lerobot/common/robots/robot.py | 64 +++++++++++++++++++++++++++++ lerobot/common/robots/robot_abc.py | 24 ----------- 4 files changed, 81 insertions(+), 83 deletions(-) create mode 100644 lerobot/common/robots/config.py delete mode 100644 lerobot/common/robots/config_abc.py create mode 100644 lerobot/common/robots/robot.py delete mode 100644 lerobot/common/robots/robot_abc.py diff --git a/lerobot/common/robots/config.py b/lerobot/common/robots/config.py new file mode 100644 index 00000000..83a13ca9 --- /dev/null +++ b/lerobot/common/robots/config.py @@ -0,0 +1,17 @@ +import abc +from dataclasses import dataclass +from pathlib import Path + +import draccus + + +@dataclass(kw_only=True) +class RobotConfig(draccus.ChoiceRegistry, abc.ABC): + # Allows to distinguish between different robots of the same type + id: str | None = None + # Directory to store calibration file + calibration_dir: Path | None = None + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) diff --git a/lerobot/common/robots/config_abc.py b/lerobot/common/robots/config_abc.py deleted file mode 100644 index 7a390ead..00000000 --- a/lerobot/common/robots/config_abc.py +++ /dev/null @@ -1,59 +0,0 @@ -import abc -from dataclasses import dataclass, field -from typing import Sequence - -import draccus - -from lerobot.common.cameras.configs import CameraConfig -from lerobot.common.motors.configs import MotorsBusConfig - - -@dataclass -class RobotConfig(draccus.ChoiceRegistry, abc.ABC): - @property - def type(self) -> str: - return self.get_choice_name(self.__class__) - - -# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction -@dataclass -class ManipulatorRobotConfig(RobotConfig): - leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) - follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) - cameras: dict[str, CameraConfig] = field(default_factory=lambda: {}) - - # Optionally limit the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length - # as the number of motors in your follower arms (assumes all follower arms have the same number of - # motors). - max_relative_target: list[float] | float | None = None - - # Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it - # possible to squeeze the gripper and have it spring back to an open position on its own. If None, the - # gripper is not put in torque mode. - gripper_open_degree: float | None = None - - mock: bool = False - - def __post_init__(self): - if self.mock: - for arm in self.leader_arms.values(): - if not arm.mock: - arm.mock = True - for arm in self.follower_arms.values(): - if not arm.mock: - arm.mock = True - for cam in self.cameras.values(): - if not cam.mock: - cam.mock = True - - if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence): - for name in self.follower_arms: - if len(self.follower_arms[name].motors) != len(self.max_relative_target): - raise ValueError( - f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has " - f"{len(self.follower_arms[name].motors)} motors. Please make sure that the " - f"`max_relative_target` list has as many parameters as there are motors per arm. " - "Note: This feature does not yet work with robots where different follower arms have " - "different numbers of motors." - ) diff --git a/lerobot/common/robots/robot.py b/lerobot/common/robots/robot.py new file mode 100644 index 00000000..50fd9154 --- /dev/null +++ b/lerobot/common/robots/robot.py @@ -0,0 +1,64 @@ +import abc + +import numpy as np + +from lerobot.common.constants import HF_LEROBOT_CALIBRATION, ROBOTS + +from .config import RobotConfig + + +class Robot(abc.ABC): + """The main LeRobot class for implementing robots.""" + + # Set these in ALL subclasses + config_class: RobotConfig + name: str + + def __init__(self, config: RobotConfig): + self.robot_type = self.name + self.calibration_dir = ( + config.calibration_dir if config.calibration_dir else HF_LEROBOT_CALIBRATION / ROBOTS / self.name + ) + self.calibration_dir.mkdir(parents=True, exist_ok=True) + + # TODO(aliberts): create a proper Feature class for this that links with datasets + @abc.abstractproperty + def state_feature(self) -> dict: + pass + + @abc.abstractproperty + def action_feature(self) -> dict: + pass + + @abc.abstractproperty + def camera_features(self) -> dict[str, dict]: + pass + + @abc.abstractmethod + def connect(self) -> None: + """Connects to the robot.""" + pass + + @abc.abstractmethod + def calibrate(self) -> None: + """Calibrates the robot.""" + pass + + @abc.abstractmethod + def get_observation(self) -> dict[str, np.ndarray]: + """Gets observation from the robot.""" + pass + + @abc.abstractmethod + def send_action(self, action: np.ndarray) -> np.ndarray: + """Sends actions to the robot.""" + pass + + @abc.abstractmethod + def disconnect(self) -> None: + """Disconnects from the robot.""" + pass + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() diff --git a/lerobot/common/robots/robot_abc.py b/lerobot/common/robots/robot_abc.py deleted file mode 100644 index 3b592f0a..00000000 --- a/lerobot/common/robots/robot_abc.py +++ /dev/null @@ -1,24 +0,0 @@ -import abc - - -class Robot(abc.ABC): - robot_type: str - features: dict - - @abc.abstractmethod - def connect(self): ... - - @abc.abstractmethod - def calibrate(self): ... - - @abc.abstractmethod - def teleop_step(self, record_data=False): ... - - @abc.abstractmethod - def capture_observation(self): ... - - @abc.abstractmethod - def send_action(self, action): ... - - @abc.abstractmethod - def disconnect(self): ...