Add safety limits on relative action target (#373)
This commit is contained in:
parent
97086cdcdf
commit
9ce98bb93c
|
@ -1,7 +1,9 @@
|
|||
import logging
|
||||
import pickle
|
||||
import time
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -164,11 +166,30 @@ class KochRobotConfig:
|
|||
follower_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
|
||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||
|
||||
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
|
||||
# as the number of motors in your follower arms (assumes all follower arms have the same number of
|
||||
# motors).
|
||||
max_relative_target: list[float] | float | None = None
|
||||
|
||||
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
|
||||
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
|
||||
# gripper is not put in torque mode.
|
||||
gripper_open_degree: float | None = None
|
||||
|
||||
def __setattr__(self, prop: str, val):
    """Set attribute `prop` to `val`, validating `max_relative_target` first.

    When `max_relative_target` is assigned a sequence, its length is compared with the number of
    motors of every follower arm currently stored in `self.follower_arms`.

    Args:
        prop: name of the attribute being assigned.
        val: value being assigned.

    Raises:
        ValueError: if `val` is a sequence whose length differs from the motor count of any
            follower arm.
    """
    # `None` is never a `Sequence`, so a single isinstance check covers the "unset" case too.
    sets_per_motor_limits = prop == "max_relative_target" and isinstance(val, Sequence)
    if sets_per_motor_limits:
        for name, arm in self.follower_arms.items():
            if len(arm.motors) == len(val):
                continue
            raise ValueError(
                f"len(max_relative_target)={len(val)} but the follower arm with name {name} has "
                f"{len(arm.motors)} motors. Please make sure that the "
                f"`max_relative_target` list has as many parameters as there are motors per arm. "
                "Note: This feature does not yet work with robots where different follower arms have "
                "different numbers of motors."
            )
    # Validation passed (or did not apply): perform the actual assignment.
    super().__setattr__(prop, val)
|
||||
|
||||
|
||||
class KochRobot:
|
||||
# TODO(rcadene): Implement force feedback
|
||||
|
@ -210,7 +231,10 @@ class KochRobot:
|
|||
},
|
||||
),
|
||||
}
|
||||
robot = KochRobot(leader_arms, follower_arms)
|
||||
robot = KochRobot(
|
||||
leader_arms=leader_arms,
|
||||
follower_arms=follower_arms,
|
||||
)
|
||||
|
||||
# Connect motors buses and cameras if any (Required)
|
||||
robot.connect()
|
||||
|
@ -222,7 +246,10 @@ class KochRobot:
|
|||
Example of highest frequency data collection without camera:
|
||||
```python
|
||||
# Assumes leader and follower arms have been instantiated already (see first example)
|
||||
robot = KochRobot(leader_arms, follower_arms)
|
||||
robot = KochRobot(
|
||||
leader_arms=leader_arms,
|
||||
follower_arms=follower_arms,
|
||||
)
|
||||
robot.connect()
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
@ -240,7 +267,11 @@ class KochRobot:
|
|||
}
|
||||
|
||||
# Assumes leader and follower arms have been instantiated already (see first example)
|
||||
robot = KochRobot(leader_arms, follower_arms, cameras)
|
||||
robot = KochRobot(
|
||||
leader_arms=leader_arms,
|
||||
follower_arms=follower_arms,
|
||||
cameras=cameras,
|
||||
)
|
||||
robot.connect()
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
@ -249,7 +280,11 @@ class KochRobot:
|
|||
Example of controlling the robot with a policy (without running multiple policies in parallel to ensure highest frequency):
|
||||
```python
|
||||
# Assumes leader and follower arms + cameras have been instantiated already (see previous example)
|
||||
robot = KochRobot(leader_arms, follower_arms, cameras)
|
||||
robot = KochRobot(
|
||||
leader_arms=leader_arms,
|
||||
follower_arms=follower_arms,
|
||||
cameras=cameras,
|
||||
)
|
||||
robot.connect()
|
||||
while True:
|
||||
# Uses the follower arms and cameras to capture an observation
|
||||
|
@ -397,7 +432,7 @@ class KochRobot:
|
|||
# Send action
|
||||
for name in self.follower_arms:
|
||||
before_fwrite_t = time.perf_counter()
|
||||
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name])
|
||||
self.send_action(torch.tensor(follower_goal_pos[name]), [name])
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
|
||||
|
||||
# Early exit when recording data is not requested
|
||||
|
@ -479,21 +514,55 @@ class KochRobot:
|
|||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor):
|
||||
"""The provided action is expected to be a vector."""
|
||||
def send_action(self, action: torch.Tensor, follower_names: list[str] | None = None):
|
||||
"""Command the follower arms to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
`max_relative_target`.
|
||||
|
||||
Args:
|
||||
action: tensor containing the concatenated joint positions for the follower arms.
|
||||
follower_names: Pass follower arm names to only control a subset of all the follower arms.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
"KochRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
if follower_names is None:
|
||||
follower_names = list(self.follower_arms)
|
||||
elif not set(follower_names).issubset(self.follower_arms):
|
||||
raise ValueError(
|
||||
f"You provided {follower_names=} but only the following arms are registered: "
|
||||
f"{list(self.follower_arms)}"
|
||||
)
|
||||
|
||||
from_idx = 0
|
||||
to_idx = 0
|
||||
follower_goal_pos = {}
|
||||
for name in self.follower_arms:
|
||||
if name in self.follower_arms:
|
||||
to_idx += len(self.follower_arms[name].motor_names)
|
||||
follower_goal_pos[name] = action[from_idx:to_idx].numpy()
|
||||
from_idx = to_idx
|
||||
for name in follower_names:
|
||||
to_idx += len(self.follower_arms[name].motor_names)
|
||||
this_action = action[from_idx:to_idx]
|
||||
|
||||
if self.config.max_relative_target is not None:
|
||||
if not isinstance(self.config.max_relative_target, list):
|
||||
max_relative_target = [self.config.max_relative_target for _ in range(from_idx, to_idx)]
|
||||
max_relative_target = torch.tensor(self.config.max_relative_target)
|
||||
# Cap relative action target magnitude for safety.
|
||||
current_pos = torch.tensor(self.follower_arms[name].read("Present_Position"))
|
||||
diff = this_action - current_pos
|
||||
safe_diff = torch.minimum(diff, max_relative_target)
|
||||
safe_diff = torch.maximum(safe_diff, -max_relative_target)
|
||||
safe_action = current_pos + safe_diff
|
||||
if not torch.allclose(safe_action, action):
|
||||
logging.warning(
|
||||
"Relative action magnitude had to be clamped to be safe.\n"
|
||||
f" requested relative action target: {diff}\n"
|
||||
f" clamped relative action target: {safe_diff}"
|
||||
)
|
||||
|
||||
follower_goal_pos[name] = safe_action.numpy()
|
||||
from_idx = to_idx
|
||||
|
||||
for name in self.follower_arms:
|
||||
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name].astype(np.int32))
|
||||
|
|
|
@ -37,6 +37,10 @@ cameras:
|
|||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: null
|
||||
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
gripper_open_degree: 35.156
|
||||
|
|
Loading…
Reference in New Issue