77 lines
3.1 KiB
Python
77 lines
3.1 KiB
Python
"""
Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py
"""
|
|
|
|
from typing import Any, Dict
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Converts gripper actions from continuous to binary values (0 and 1).

    We exploit the fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
    transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
    values based on the state that is reached _after_ those intermediate values.

    In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
    chunk of intermediate values as the last action in the trajectory.

    The `scan_fn` implements the following logic:
        new_actions = np.empty_like(actions)
        carry = float(actions[-1])
        for i in reversed(range(actions.shape[0])):
            if in_between_mask[i]:
                carry = carry
            else:
                carry = float(open_mask[i])
            new_actions[i] = carry

    Args:
        actions: Gripper actions of a single trajectory. Each `actions[i]` is used as a scalar `tf.cond` predicate,
            so a non-empty 1-D tensor is expected (`actions[-1]` must exist).

    Returns:
        A float32 tensor with one entry per timestep: 1.0 where the gripper is (or is about to become) open,
        0.0 where it is (or is about to become) closed. A trailing run of intermediate values keeps the value of
        the last action instead.
    """
    # Values above 0.95 count as "open", below 0.05 as "closed"; everything else is a transition value.
    open_mask, closed_mask = actions > 0.95, actions < 0.05
    in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
    is_open_float = tf.cast(open_mask, tf.float32)

    def scan_fn(carry, i):
        # Transition values inherit the label of the next settled state (the scan runs back-to-front).
        return tf.cond(in_between_mask[i], lambda: carry, lambda: is_open_float[i])

    # `tf.scan` takes the accumulator dtype from the initializer, and every step returns float32. Cast the initial
    # carry up front so inputs that are not float32 (e.g. float64) do not trigger a dtype mismatch inside the scan.
    initial_carry = tf.cast(actions[-1], tf.float32)

    return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), initial_carry, reverse=True)
|
|
|
|
|
|
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """Flips gripper actions by returning `1 - actions`, so 0 maps to 1 and 1 maps to 0."""
    flipped = 1 - actions
    return flipped
|
|
|
|
|
|
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Turns relative gripper commands into absolute gripper states.

    Input convention: values above 0.1 are "close" commands, values below -0.1 are "open" commands, and anything in
    between means "no change". Output convention: 0.0 = closed, 1.0 = open.

    The state before the first command is taken to be the opposite of that command, i.e. the first command is
    assumed to actually change the gripper (it is not, e.g., a "close" issued while already closed). If the
    trajectory contains no command at all, the gripper is treated as open throughout.

    Args:
        actions: Relative gripper actions of a single trajectory. Each element is used as a scalar `tf.cond`
            predicate, so a 1-D tensor is expected.

    Returns:
        A float32 tensor holding the absolute gripper state (0.0 or 1.0) at every timestep.
    """
    wants_open = actions < -0.1
    wants_close = actions > 0.1

    # Internal encoding =>> 1 = open, -1 = closed, 0 = keep the previous state
    close_or_keep = tf.where(wants_close, -1, 0)
    commands = tf.where(wants_open, 1, close_or_keep)

    # The initial state is the opposite of the first non-zero command. With no command at all, argmax yields
    # index 0, the looked-up command is 0, and we fall back to "open" (1).
    first_command_idx = tf.argmax(commands != 0, axis=0)
    initial_state = -1 * commands[first_command_idx]
    initial_state = tf.cond(initial_state == 0, lambda: 1, lambda: initial_state)

    def hold_or_switch(state, t):
        # A zero command keeps the running state; any other command becomes the new state.
        return tf.cond(commands[t] == 0, lambda: state, lambda: commands[t])

    timesteps = tf.range(tf.shape(actions)[0])
    states = tf.scan(hold_or_switch, timesteps, initial_state)

    # Map {-1, 1} onto {0.0, 1.0}
    return tf.cast(states, tf.float32) / 2 + 0.5
|
|
|
|
|
|
# === Bridge-V2 =>> Dataset-Specific Transform ===
|
|
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
    """
    Replaces the movement part of each action with the state change that was actually reached.

    The new movement action at step t is `state[t + 1, :6] - state[t, :6]`; the last action column (presumably the
    gripper command -- confirm against the dataset spec) is carried over unchanged. Because the final timestep has
    no successor state, every entry of the trajectory is shortened by one step.

    Args:
        traj: Trajectory structure containing at least `traj["observation"]["state"]` and `traj["action"]`, both
            indexed as (time, feature). Every leaf must support slicing along its first axis.

    Returns:
        A structure with the same layout as `traj`, each leaf missing its last timestep, and with "action" replaced
        by the relabeled actions.
    """
    proprio = traj["observation"]["state"]
    reached_deltas = proprio[1:, :6] - proprio[:-1, :6]
    last_action_column = traj["action"][:-1, -1:]

    # Drop the final timestep from every leaf so all entries stay aligned with the relabeled actions.
    relabeled = tf.nest.map_structure(lambda leaf: leaf[:-1], traj)
    relabeled["action"] = tf.concat([reached_deltas, last_action_column], axis=1)

    return relabeled
|