Update ensure_safe_goal_position

This commit is contained in:
Simon Alibert 2025-03-23 19:43:58 +01:00
parent 7582a0a2b0
commit 039b437ef0
1 changed files with 31 additions and 14 deletions

View File

@ -1,8 +1,7 @@
import logging
from pprint import pformat
from typing import Protocol
import numpy as np
from lerobot.common.robots import RobotConfig
@ -81,20 +80,38 @@ def make_robot(robot_type: str, **kwargs) -> Robot:
def ensure_safe_goal_position(
goal_pos: np.ndarray, present_pos: np.ndarray, max_relative_target: float | list[float]
):
# Cap relative action target magnitude for safety.
diff = goal_pos - present_pos
max_relative_target = np.array(max_relative_target)
safe_diff = np.min(diff, max_relative_target)
safe_diff = np.max(safe_diff, -max_relative_target)
safe_goal_pos = present_pos + safe_diff
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
) -> dict[str, float]:
"""Caps relative action target magnitude for safety."""
if not np.allclose(goal_pos, safe_goal_pos):
if isinstance(max_relative_target, float):
diff_cap = {key: max_relative_target for key in goal_present_pos}
elif isinstance(max_relative_target, dict):
if not set(goal_present_pos) == set(max_relative_target):
raise ValueError("max_relative_target keys must match those of goal_present_pos.")
diff_cap = max_relative_target
else:
raise TypeError(max_relative_target)
warnings_dict = {}
safe_goal_positions = {}
for key, (goal_pos, present_pos) in goal_present_pos.items():
diff = goal_pos - present_pos
max_diff = diff_cap[key]
safe_diff = min(diff, max_diff)
safe_diff = max(safe_diff, -max_diff)
safe_goal_pos = present_pos + safe_diff
safe_goal_positions[key] = safe_goal_pos
if abs(safe_goal_pos - goal_pos) > 1e-4:
warnings_dict[key] = {
"original goal_pos": goal_pos,
"safe goal_pos": safe_goal_pos,
}
if warnings_dict:
logging.warning(
"Relative goal position magnitude had to be clamped to be safe.\n"
f" requested relative goal position target: {diff}\n"
f" clamped relative goal position target: {safe_diff}"
f"{pformat(warnings_dict, indent=4)}"
)
return safe_goal_pos
return safe_goal_positions