Fix base robot class

This commit is contained in:
Simon Alibert 2025-03-03 18:49:40 +01:00
parent c0137e89b9
commit c93cbb8311
4 changed files with 81 additions and 83 deletions

View File

@ -0,0 +1,17 @@
import abc
from dataclasses import dataclass
from pathlib import Path
import draccus
@dataclass(kw_only=True)
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
# Allows to distinguish between different robots of the same type
id: str | None = None
# Directory to store calibration file
calibration_dir: Path | None = None
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)

View File

@ -1,59 +0,0 @@
import abc
from dataclasses import dataclass, field
from typing import Sequence
import draccus
from lerobot.common.cameras.configs import CameraConfig
from lerobot.common.motors.configs import MotorsBusConfig
@dataclass
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction
@dataclass
class ManipulatorRobotConfig(RobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
# as the number of motors in your follower arms (assumes all follower arms have the same number of
# motors).
max_relative_target: list[float] | float | None = None
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
mock: bool = False
def __post_init__(self):
if self.mock:
for arm in self.leader_arms.values():
if not arm.mock:
arm.mock = True
for arm in self.follower_arms.values():
if not arm.mock:
arm.mock = True
for cam in self.cameras.values():
if not cam.mock:
cam.mock = True
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
raise ValueError(
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
f"`max_relative_target` list has as many parameters as there are motors per arm. "
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)

View File

@ -0,0 +1,64 @@
import abc
import numpy as np
from lerobot.common.constants import HF_LEROBOT_CALIBRATION, ROBOTS
from .config import RobotConfig
class Robot(abc.ABC):
"""The main LeRobot class for implementing robots."""
# Set these in ALL subclasses
config_class: RobotConfig
name: str
def __init__(self, config: RobotConfig):
self.robot_type = self.name
self.calibration_dir = (
config.calibration_dir if config.calibration_dir else HF_LEROBOT_CALIBRATION / ROBOTS / self.name
)
self.calibration_dir.mkdir(parents=True, exist_ok=True)
# TODO(aliberts): create a proper Feature class for this that links with datasets
@abc.abstractproperty
def state_feature(self) -> dict:
pass
@abc.abstractproperty
def action_feature(self) -> dict:
pass
@abc.abstractproperty
def camera_features(self) -> dict[str, dict]:
pass
@abc.abstractmethod
def connect(self) -> None:
"""Connects to the robot."""
pass
@abc.abstractmethod
def calibrate(self) -> None:
"""Calibrates the robot."""
pass
@abc.abstractmethod
def get_observation(self) -> dict[str, np.ndarray]:
"""Gets observation from the robot."""
pass
@abc.abstractmethod
def send_action(self, action: np.ndarray) -> np.ndarray:
"""Sends actions to the robot."""
pass
@abc.abstractmethod
def disconnect(self) -> None:
"""Disconnects from the robot."""
pass
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()

View File

@ -1,24 +0,0 @@
import abc
class Robot(abc.ABC):
robot_type: str
features: dict
@abc.abstractmethod
def connect(self): ...
@abc.abstractmethod
def calibrate(self): ...
@abc.abstractmethod
def teleop_step(self, record_data=False): ...
@abc.abstractmethod
def capture_observation(self): ...
@abc.abstractmethod
def send_action(self, action): ...
@abc.abstractmethod
def disconnect(self): ...