"""
Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py
"""

from typing import Any, Dict

import tensorflow as tf


def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Snap continuous gripper actions to binary values (0.0 = closed, 1.0 = open).

    Values above 0.95 are treated as "open" and values below 0.05 as "closed". Anything else is
    considered an intermediate value seen while the gripper transitions between the two states.
    Each intermediate value is relabeled with the state that is reached _after_ it, which is why
    the scan below walks the trajectory back-to-front.

    Edge case: if the trajectory ends on an intermediate value there is no later state to copy,
    so that trailing chunk is relabeled with the (un-binarized) last action of the trajectory.

    Reference logic implemented by the scan:

        new_actions = np.empty_like(actions)
        carry = actions[-1]
        for i in reversed(range(actions.shape[0])):
            if in_between_mask[i]:
                carry = carry
            else:
                carry = float(open_mask[i])
            new_actions[i] = carry

    :param actions: gripper actions, indexed by timestep along axis 0.
    :return: tensor with one relabeled gripper value per timestep.
    """
    is_open = actions > 0.95
    is_closed = actions < 0.05
    is_extreme = tf.logical_or(is_open, is_closed)
    is_transition = tf.logical_not(is_extreme)
    open_as_float = tf.cast(is_open, tf.float32)

    def _relabel_step(prev, idx):
        # Intermediate step => keep propagating the value found later in time;
        # otherwise restart the carry from this step's own open/closed label.
        return tf.cond(
            is_transition[idx],
            lambda: tf.cast(prev, tf.float32),
            lambda: open_as_float[idx],
        )

    step_indices = tf.range(tf.shape(actions)[0])
    return tf.scan(_relabel_step, step_indices, initializer=actions[-1], reverse=True)


def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """Flip the gripper convention by mapping each action `a` to `1 - a`."""
    return 1 - actions


def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Turn relative gripper commands into absolute gripper states.

    Input convention: positive values (> 0.1) mean "close", negative values (< -0.1) mean "open",
    and anything in between means "no change". Output convention: 0.0 = closed, 1.0 = open.

    Assumes that the first relative command is not redundant (e.g. "close" when already closed);
    the initial state is taken to be the opposite of that command.

    :param actions: relative gripper commands, indexed by timestep along axis 0.
    :return: float32 tensor of absolute gripper states, one per timestep.
    """
    # Internal sign convention from here on: +1 = open, -1 = close, 0 = no change.
    wants_open = actions < -0.1
    wants_close = actions > 0.1
    signed_cmds = tf.where(wants_open, 1, tf.where(wants_close, -1, 0))

    def _hold_step(prev, idx):
        # A zero command keeps the previous state; a nonzero command replaces it.
        return tf.cond(signed_cmds[idx] == 0, lambda: prev, lambda: signed_cmds[idx])

    # Start in the opposite state of a timestep that carries a command.
    # NOTE(review): argmax over the boolean mask is presumably meant to pick the *first* such
    # timestep -- confirm tie-breaking against the TensorFlow version in use.
    cmd_idx = tf.argmax(signed_cmds != 0, axis=0)
    initial_state = -1 * signed_cmds[cmd_idx]
    # If the selected entry is zero (no command found), assume open for the whole trajectory.
    initial_state = tf.cond(initial_state == 0, lambda: 1, lambda: initial_state)

    # Scan result uses -1 for closed and +1 for open; rescale to 0.0 / 1.0.
    step_indices = tf.range(tf.shape(actions)[0])
    held_states = tf.scan(_hold_step, step_indices, initial_state)
    return tf.cast(held_states, tf.float32) / 2 + 0.5


# === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
    """
    Relabel actions with the proprioceptive state that was actually reached.

    The first six action dimensions are replaced by the difference between consecutive
    `traj["observation"]["state"][:, :6]` entries; the last action column is kept from the
    original actions. Because the final timestep has no successor state, every entry of the
    trajectory is truncated by one step.

    :param traj: trajectory dict containing at least `"observation"]["state"` and `"action"`.
    :return: a new, one-step-shorter trajectory structure with relabeled `"action"`.
    """
    state = traj["observation"]["state"]
    movement_actions = state[1:, :6] - state[:-1, :6]
    last_action_column = traj["action"][:-1, -1:]

    # Drop the final timestep from every leaf so all entries line up with the new actions.
    traj_truncated = tf.nest.map_structure(lambda leaf: leaf[:-1], traj)
    traj_truncated["action"] = tf.concat([movement_actions, last_action_column], axis=1)
    return traj_truncated