diff --git a/lerobot/common/robots/utils.py b/lerobot/common/robots/utils.py index 12b9dac5..7e538376 100644 --- a/lerobot/common/robots/utils.py +++ b/lerobot/common/robots/utils.py @@ -1,8 +1,7 @@ import logging +from pprint import pformat from typing import Protocol -import numpy as np - from lerobot.common.robots import RobotConfig @@ -81,20 +80,38 @@ def make_robot(robot_type: str, **kwargs) -> Robot: def ensure_safe_goal_position( - goal_pos: np.ndarray, present_pos: np.ndarray, max_relative_target: float | list[float] -): - # Cap relative action target magnitude for safety. - diff = goal_pos - present_pos - max_relative_target = np.array(max_relative_target) - safe_diff = np.min(diff, max_relative_target) - safe_diff = np.max(safe_diff, -max_relative_target) - safe_goal_pos = present_pos + safe_diff + goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float] +) -> dict[str, float]: + """Caps relative action target magnitude for safety.""" - if not np.allclose(goal_pos, safe_goal_pos): + if isinstance(max_relative_target, float): + diff_cap = {key: max_relative_target for key in goal_present_pos} + elif isinstance(max_relative_target, dict): + if not set(goal_present_pos) == set(max_relative_target): + raise ValueError("max_relative_target keys must match those of goal_present_pos.") + diff_cap = max_relative_target + else: + raise TypeError(max_relative_target) + + warnings_dict = {} + safe_goal_positions = {} + for key, (goal_pos, present_pos) in goal_present_pos.items(): + diff = goal_pos - present_pos + max_diff = diff_cap[key] + safe_diff = min(diff, max_diff) + safe_diff = max(safe_diff, -max_diff) + safe_goal_pos = present_pos + safe_diff + safe_goal_positions[key] = safe_goal_pos + if abs(safe_goal_pos - goal_pos) > 1e-4: + warnings_dict[key] = { + "original goal_pos": goal_pos, + "safe goal_pos": safe_goal_pos, + } + + if warnings_dict: logging.warning( "Relative goal position magnitude had to be clamped to be safe.\n" - f" requested relative goal position target: {diff}\n" - f" clamped relative goal position target: {safe_diff}" + f"{pformat(warnings_dict, indent=4)}" ) - return safe_goal_pos + return safe_goal_positions