Update ensure_safe_goal_position
This commit is contained in:
parent
7582a0a2b0
commit
039b437ef0
|
@ -1,8 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
from pprint import pformat
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from lerobot.common.robots import RobotConfig
|
from lerobot.common.robots import RobotConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,20 +80,38 @@ def make_robot(robot_type: str, **kwargs) -> Robot:
|
||||||
|
|
||||||
|
|
||||||
def ensure_safe_goal_position(
|
def ensure_safe_goal_position(
|
||||||
goal_pos: np.ndarray, present_pos: np.ndarray, max_relative_target: float | list[float]
|
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
|
||||||
):
|
) -> dict[str, float]:
|
||||||
# Cap relative action target magnitude for safety.
|
"""Caps relative action target magnitude for safety."""
|
||||||
diff = goal_pos - present_pos
|
|
||||||
max_relative_target = np.array(max_relative_target)
|
|
||||||
safe_diff = np.min(diff, max_relative_target)
|
|
||||||
safe_diff = np.max(safe_diff, -max_relative_target)
|
|
||||||
safe_goal_pos = present_pos + safe_diff
|
|
||||||
|
|
||||||
if not np.allclose(goal_pos, safe_goal_pos):
|
if isinstance(max_relative_target, float):
|
||||||
|
diff_cap = {key: max_relative_target for key in goal_present_pos}
|
||||||
|
elif isinstance(max_relative_target, dict):
|
||||||
|
if not set(goal_present_pos) == set(max_relative_target):
|
||||||
|
raise ValueError("max_relative_target keys must match those of goal_present_pos.")
|
||||||
|
diff_cap = max_relative_target
|
||||||
|
else:
|
||||||
|
raise TypeError(max_relative_target)
|
||||||
|
|
||||||
|
warnings_dict = {}
|
||||||
|
safe_goal_positions = {}
|
||||||
|
for key, (goal_pos, present_pos) in goal_present_pos.items():
|
||||||
|
diff = goal_pos - present_pos
|
||||||
|
max_diff = diff_cap[key]
|
||||||
|
safe_diff = min(diff, max_diff)
|
||||||
|
safe_diff = max(safe_diff, -max_diff)
|
||||||
|
safe_goal_pos = present_pos + safe_diff
|
||||||
|
safe_goal_positions[key] = safe_goal_pos
|
||||||
|
if abs(safe_goal_pos - goal_pos) > 1e-4:
|
||||||
|
warnings_dict[key] = {
|
||||||
|
"original goal_pos": goal_pos,
|
||||||
|
"safe goal_pos": safe_goal_pos,
|
||||||
|
}
|
||||||
|
|
||||||
|
if warnings_dict:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||||
f" requested relative goal position target: {diff}\n"
|
f"{pformat(warnings_dict, indent=4)}"
|
||||||
f" clamped relative goal position target: {safe_diff}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return safe_goal_pos
|
return safe_goal_positions
|
||||||
|
|
Loading…
Reference in New Issue