Merge remote-tracking branch 'tavish9_lerobot_openx/main' into user/rcadene/2025_02_19_port_openx

This commit is contained in:
Remi Cadene 2025-02-19 12:58:18 +00:00
commit 76436ca1de
4 changed files with 2217 additions and 0 deletions

View File

@ -0,0 +1,308 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
Example:
python openx_rlds.py \
--raw-dir /path/to/bridge_orig/1.0.0 \
--local-dir /path/to/local_dir \
--repo-id your_id \
--use-videos \
--push-to-hub
"""
import argparse
import os
import re
import shutil
import sys
from functools import partial
from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
current_dir = os.path.dirname(os.path.abspath(__file__))
oxe_utils_dir = os.path.join(current_dir, "oxe_utils")
sys.path.append(oxe_utils_dir)
from oxe_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
np.set_printoptions(precision=2)
def transform_raw_dataset(episode, dataset_name):
traj = next(iter(episode["steps"].batch(episode["steps"].cardinality())))
if dataset_name in OXE_STANDARDIZATION_TRANSFORMS:
traj = OXE_STANDARDIZATION_TRANSFORMS[dataset_name](traj)
if dataset_name in OXE_DATASET_CONFIGS:
state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
else:
state_obs_keys = [None for _ in range(8)]
proprio = tf.concat(
[
(
tf.zeros((tf.shape(traj["action"])[0], 1), dtype=tf.float32) # padding
if key is None
else tf.cast(traj["observation"][key], tf.float32)
)
for key in state_obs_keys
],
axis=1,
)
traj.update(
{
"proprio": proprio,
"task": traj.pop("language_instruction"),
"action": tf.cast(traj["action"], tf.float32),
}
)
episode["steps"] = traj
return episode
def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bool = True):
dataset_name = builder.name
state_names = [f"motor_{i}" for i in range(8)]
if dataset_name in OXE_DATASET_CONFIGS:
state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"]
if state_encoding == StateEncoding.POS_EULER:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
if "libero" in dataset_name:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state
elif state_encoding == StateEncoding.POS_QUAT:
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
DEFAULT_FEATURES = {
"observation.state": {
"dtype": "float32",
"shape": (8,),
"names": {"motors": state_names},
},
"action": {
"dtype": "float32",
"shape": (7,),
"names": {"motors": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]},
},
}
obs = builder.info.features["steps"]["observation"]
features = {
f"observation.images.{key}": {
"dtype": "video" if use_videos else "image",
"shape": value.shape,
"names": ["height", "width", "rgb"],
}
for key, value in obs.items()
if "depth" not in key and any(x in key for x in ["image", "rgb"])
}
return {**features, **DEFAULT_FEATURES}
def save_as_lerobot_dataset(lerobot_dataset: LeRobotDataset, raw_dataset: tf.data.Dataset, **kwargs):
for episode in raw_dataset.as_numpy_iterator():
traj = episode["steps"]
for i in range(traj["action"].shape[0]):
image_dict = {
f"observation.images.{key}": value[i]
for key, value in traj["observation"].items()
if "depth" not in key and any(x in key for x in ["image", "rgb"])
}
lerobot_dataset.add_frame(
{
**image_dict,
"observation.state": traj["proprio"][i],
"action": traj["action"][i],
}
)
lerobot_dataset.save_episode(task=traj["task"][0].decode())
lerobot_dataset.consolidate(
run_compute_stats=True,
keep_image_files=kwargs["keep_images"],
stat_kwargs={"batch_size": kwargs["batch_size"], "num_workers": kwargs["num_workers"]},
)
def create_lerobot_dataset(
raw_dir: Path,
repo_id: str = None,
local_dir: Path = None,
push_to_hub: bool = False,
fps: int = None,
robot_type: str = None,
use_videos: bool = True,
batch_size: int = 32,
num_workers: int = 8,
image_writer_process: int = 5,
image_writer_threads: int = 10,
keep_images: bool = True,
):
last_part = raw_dir.name
if re.match(r"^\d+\.\d+\.\d+$", last_part):
version = last_part
dataset_name = raw_dir.parent.name
data_dir = raw_dir.parent.parent
else:
version = ""
dataset_name = last_part
data_dir = raw_dir.parent
if local_dir is None:
local_dir = Path(LEROBOT_HOME)
local_dir /= f"{dataset_name}_{version}_lerobot"
if local_dir.exists():
shutil.rmtree(local_dir)
builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
features = generate_features_from_raw(builder, use_videos)
raw_dataset = builder.as_dataset(split="train").map(partial(transform_raw_dataset, dataset_name=dataset_name))
if fps is None:
if dataset_name in OXE_DATASET_CONFIGS:
fps = OXE_DATASET_CONFIGS[dataset_name]["control_frequency"]
else:
fps = 10
if robot_type is None:
if dataset_name in OXE_DATASET_CONFIGS:
robot_type = OXE_DATASET_CONFIGS[dataset_name]["robot_type"]
robot_type = robot_type.lower().replace(" ", "_").replace("-", "_")
else:
robot_type = "unknown"
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
robot_type=robot_type,
root=local_dir,
fps=fps,
use_videos=use_videos,
features=features,
image_writer_threads=image_writer_threads,
image_writer_processes=image_writer_process,
)
save_as_lerobot_dataset(
lerobot_dataset, raw_dataset, keep_images=keep_images, batch_size=batch_size, num_workers=num_workers
)
if push_to_hub:
assert repo_id is not None
tags = ["LeRobot", dataset_name, "rlds"]
if dataset_name in OXE_DATASET_CONFIGS:
tags.append("openx")
if robot_type != "unknown":
tags.append(robot_type)
lerobot_dataset.push_to_hub(
tags=tags,
private=False,
push_videos=True,
license="apache-2.0",
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
)
parser.add_argument(
"--local-dir",
type=Path,
required=True,
help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
)
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Upload to hub.",
)
parser.add_argument(
"--robot-type",
type=str,
default=None,
help="Robot type of this dataset.",
)
parser.add_argument(
"--fps",
type=int,
default=None,
help="Frame rate used to collect videos. Default fps equals to the control frequency of the robot.",
)
parser.add_argument(
"--use-videos",
action="store_true",
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
)
parser.add_argument(
"--batch-size",
type=int,
default=32,
help="Batch size loaded by DataLoader for computing the dataset statistics.",
)
parser.add_argument(
"--num-workers",
type=int,
default=8,
help="Number of processes of Dataloader for computing the dataset statistics.",
)
parser.add_argument(
"--image-writer-process",
type=int,
default=5,
help="Number of processes of image writer for saving images.",
)
parser.add_argument(
"--image-writer-threads",
type=int,
default=10,
help="Number of threads per process of image writer for saving images.",
)
parser.add_argument(
"--keep-images",
action="store_true",
help="Whether to keep the cached images.",
)
args = parser.parse_args()
create_lerobot_dataset(**vars(args))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,848 @@
"""
Adapt from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/configs.py
configs.py
Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment.
Configuration adopts the following structure:
image_obs_keys:
primary: primary external RGB
secondary: secondary external RGB
wrist: wrist RGB
depth_obs_keys:
primary: primary external depth
secondary: secondary external depth
wrist: wrist depth
# Always 8-dim =>> changes based on `StateEncoding`
state_obs_keys:
StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
StateEncoding.JOINT: Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
state_encoding: Type of `StateEncoding`
action_encoding: Type of action encoding (e.g., EEF Position vs. Joint Position)
"""
from enum import IntEnum
from typing import Dict
import tensorflow as tf
def zero_action_filter(traj: Dict) -> bool:
"""
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
"""
DROID_Q01 = tf.convert_to_tensor(
[
-0.7776297926902771,
-0.5803514122962952,
-0.5795090794563293,
-0.6464047729969025,
-0.7041108310222626,
-0.8895104378461838,
]
)
DROID_Q99 = tf.convert_to_tensor(
[
0.7597932070493698,
0.5726242214441299,
0.7351000607013702,
0.6705610305070877,
0.6464948207139969,
0.8897542208433151,
]
)
DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
# Defines Proprioceptive State Encoding Schemes
class StateEncoding(IntEnum):
# fmt: off
NONE = -1 # No Proprioceptive State
POS_EULER = 1 # EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
POS_QUAT = 2 # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
JOINT = 3 # Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
JOINT_BIMANUAL = 4 # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ])
# fmt: on
# Defines Action Encoding Schemes
class ActionEncoding(IntEnum):
# fmt: off
EEF_POS = 1 # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1)
JOINT_POS = 2 # Joint Delta Position (7) + Gripper Open/Close (1)
JOINT_POS_BIMANUAL = 3 # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ])
EEF_R6 = 4 # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1)
# fmt: on
# === Individual Dataset Configs ===
OXE_DATASET_CONFIGS = {
"fractal20220817_data": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["base_pose_tool_reached", "gripper_closed"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Google Robot",
},
"kuka": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [
"clip_function_input/base_pose_tool_reached",
"gripper_closed",
],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Kuka iiwa",
},
"bridge_oxe": { # Version of Bridge V2 in Open X-Embodiment mixture
"image_obs_keys": {"primary": "image", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "WidowX",
},
"bridge_orig": { # Original version of Bridge V2 from project website
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "WidowX",
},
"bridge_dataset": { # Original version of Bridge V2 from project website
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "WidowX",
},
"taco_play": {
"image_obs_keys": {
"primary": "rgb_static",
"secondary": None,
"wrist": "rgb_gripper",
},
"depth_obs_keys": {
"primary": "depth_static",
"secondary": None,
"wrist": "depth_gripper",
},
"state_obs_keys": ["state_eef", None, "state_gripper"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 15,
"robot_type": "Franka",
},
"jaco_play": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "image_wrist",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state_eef", None, "state_gripper"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Jaco 2",
},
"berkeley_cable_routing": {
"image_obs_keys": {
"primary": "image",
"secondary": "top_image",
"wrist": "wrist45_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["robot_state", None],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"roboturk": {
"image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Sawyer",
},
"nyu_door_opening_surprising_effectiveness": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Hello Stretch",
},
"viola": {
"image_obs_keys": {
"primary": "agentview_rgb",
"secondary": None,
"wrist": "eye_in_hand_rgb",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_states", "gripper_states"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"berkeley_autolab_ur5": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "hand_image",
},
"depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "UR5",
},
"toto": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 30,
"robot_type": "Franka",
},
"language_table": {
"image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["effector_translation", None, None, None, None, None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "xArm",
},
"columbia_cairlab_pusht_real": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["robot_state", None, None, None, None, None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "UR5",
},
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["ee_position", "ee_orientation", None],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Kuka iiwa",
},
"nyu_rot_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "xArm",
},
"stanford_hydra_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"austin_buds_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"nyu_franka_play_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": "image_additional_view",
"wrist": None,
},
"depth_obs_keys": {
"primary": "depth",
"secondary": "depth_additional_view",
"wrist": None,
},
"state_obs_keys": ["eef_state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Franka",
},
"maniskill_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {
"primary": "depth",
"secondary": None,
"wrist": "wrist_depth",
},
"state_obs_keys": ["tcp_pose", "gripper_state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"furniture_bench_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"cmu_franka_exploration_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "highres_image",
"secondary": None,
"wrist": None,
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_state", None],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 2,
"robot_type": "xArm",
},
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "xArm",
},
"austin_sailor_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"austin_sirius_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"bc_z": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [
"present/xyz",
"present/axis_angle",
None,
"present/sensed_close",
],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Google Robot",
},
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "PR2",
},
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "PR2",
},
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": "image2",
"wrist": "hand_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["end_effector_pose", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "xArm",
},
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["pose_r", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "xArm Bimanual",
},
"robo_net": {
"image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 1,
"robot_type": "Multi-Robot",
},
"berkeley_mvp_converted_externally_to_rlds": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["pose", "gripper"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.JOINT_POS,
"control_frequency": 5,
"robot_type": "xArm",
},
"berkeley_rpt_converted_externally_to_rlds": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_pos", "gripper"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.JOINT_POS,
"control_frequency": 30,
"robot_type": "Franka",
},
"kaist_nonprehensile_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"stanford_mask_vit_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": None,
"robot_type": "Sawyer",
},
"tokyo_u_lsmo_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Cobotta",
},
"dlr_sara_pour_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "DLR SARA",
},
"dlr_sara_grid_clamp_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "DLR SARA",
},
"dlr_edan_shared_control_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "DLR EDAN",
},
"asu_table_top_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 12.5,
"robot_type": "UR5",
},
"stanford_robocook_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
"depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"imperialcollege_sawyer_wrist_cam": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, "state"],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Sawyer",
},
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_state", "gripper_state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"uiuc_d3field": {
"image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
"depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 1,
"robot_type": "Kinova Gen3",
},
"utaustin_mutex": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"berkeley_fanuc_manipulation": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_state", None, "gripper_state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Fanuc Mate",
},
"cmu_playing_with_food": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "finger_vision_1",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"cmu_play_fusion": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"cmu_stretch": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Hello Stretch",
},
"berkeley_gnm_recon": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Jackal",
},
"berkeley_gnm_cory_hall": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "RC Car",
},
"berkeley_gnm_sac_son": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "TurtleBot 2",
},
# NOTE: modified
"droid": {
"image_obs_keys": {
"primary": "exterior_image_1_left",
"secondary": "exterior_image_2_left",
"wrist": "wrist_image_left",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 15,
"robot_type": "Franka",
"aux_kwargs": {
"dataset_frame_transform_kwargs": {
"chunk_filter_fn": zero_action_filter,
},
},
},
"fmb_dataset": {
"image_obs_keys": {
"primary": "image_side_1",
"secondary": "image_side_2",
"wrist": "image_wrist_1",
},
"depth_obs_keys": {
"primary": "image_side_1_depth",
"secondary": "image_side_2_depth",
"wrist": "image_wrist_1_depth",
},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
# NOTE: modified
"dobbe": {
"image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3.75,
"robot_type": "Hello Stretch",
},
"roboset": {
"image_obs_keys": {
"primary": "image_left",
"secondary": "image_right",
"wrist": "image_wrist",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.JOINT_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"rh20t": {
"image_obs_keys": {
"primary": "image_front",
"secondary": "image_side_right",
"wrist": "image_wrist",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Flexiv",
},
### T-DROID datasets
"tdroid_carrot_in_bowl": { # "put carrot in bowl" task, 50 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_pour_corn_in_pot": { # "pour corn from red bonawl into steel pot" task, 50 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_move_object_onto_plate": { # "move <object> onto plate" task, 150 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_knock_object_over": { # "knock <object> over" task, 70 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_cover_object_with_towel": { # "cover <object> with towel" task, 45 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
### DROID Finetuning datasets
"droid_wipe": {
"image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 15,
"robot_type": "Franka",
},
# NOTE: modified
### LIBERO datasets (modified versions)
"libero_spatial_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"libero_object_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"libero_goal_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"libero_10_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
}

View File

@ -0,0 +1,76 @@
"""
Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py
"""
from typing import Any, Dict
import tensorflow as tf
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts gripper actions from continuous to binary values (0 and 1).
We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
values based on the state that is reached _after_ those intermediate values.
In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
chunk of intermediate values as the last action in the trajectory.
The `scan_fn` implements the following logic:
new_actions = np.empty_like(actions)
carry = actions[-1]
for i in reversed(range(actions.shape[0])):
if in_between_mask[i]:
carry = carry
else:
carry = float(open_mask[i])
new_actions[i] = carry
"""
open_mask, closed_mask = actions > 0.95, actions < 0.05
in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
is_open_float = tf.cast(open_mask, tf.float32)
def scan_fn(carry, i):
return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
return 1 - actions
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
Assumes that the first relative gripper is not redundant (i.e. close when already closed)!
"""
# Note =>> -1 for closing, 1 for opening, 0 for no change
opening_mask, closing_mask = actions < -0.1, actions > 0.1
thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))
def scan_fn(carry, i):
return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])
# If no relative grasp, assumes open for whole trajectory
start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
start = tf.cond(start == 0, lambda: 1, lambda: start)
# Note =>> -1 for closed, 1 for open
new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5
return new_actions
# === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)
return traj_truncated

View File

@ -0,0 +1,985 @@
"""
Adapt from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/transforms.py
transforms.py
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
Transforms adopt the following structure:
Input: Dictionary of *batched* features (i.e., has leading time dimension)
Output: Dictionary `step` =>> {
"observation": {
<image_keys, depth_image_keys>
State (in chosen state representation)
},
"action": Action (in chosen action representation),
"language_instruction": str
}
"""
from typing import Any, Dict
import tensorflow as tf
from oxe_utils.transform_utils import (
binarize_gripper_actions,
invert_gripper_actions,
rel2abs_gripper_actions,
relabel_bridge_actions,
)
def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
def rand_swap_exterior_images(img1, img2):
"""
Randomly swaps the two exterior images (for training with single exterior input).
"""
return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dR,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
rand_swap_exterior_images(
trajectory["observation"]["exterior_image_1_left"],
trajectory["observation"]["exterior_image_2_left"],
)
)
# trajectory["observation"]["proprio"] = tf.concat(
# (
# trajectory["observation"]["cartesian_position"],
# trajectory["observation"]["gripper_position"],
# ),
# axis=-1,
# )
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"]
return trajectory
def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dR,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
Applies to version of Bridge V2 in Open X-Embodiment mixture.
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory.keys():
if key == "traj_metadata":
continue
elif key in ["observation", "action"]:
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
Applies to original version of Bridge V2 from the official project website.
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory.keys():
if key == "traj_metadata":
continue
elif key == "observation":
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
return trajectory
def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
# decode compressed state
eef_value = tf.io.decode_compressed(
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
compression_type="ZLIB",
)
eef_value = tf.io.decode_raw(eef_value, tf.float32)
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB")
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
trajectory["action"] = trajectory["action"]["rel_actions_world"]
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:]
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
tf.zeros_like(trajectory["action"]["world_vector"]),
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert absolute gripper action, +1 = open, 0 = close
gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1))
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
gripper_action = tf.clip_by_value(gripper_action, 0, 1)
gripper_action = invert_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth")
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# default to "open" gripper
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.ones_like(trajectory["action"][:, :1]),
),
axis=-1,
)
# decode language instruction
instruction_bytes = trajectory["observation"]["instruction"]
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0]
return trajectory
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
trajectory["action"]["gripper_closedness_action"][:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
trajectory["action"] = trajectory["action"][..., :7]
return trajectory
def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
trajectory["observation"]["state"][:, 7:10],
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
trajectory["observation"]["depth_additional_view"] = tf.cast(
trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
)
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]
# clip gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, -8:-2],
tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
return trajectory
def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["state"][:, :7],
trajectory["observation"]["state"][:, -1:],
),
axis=-1,
)
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["future/xyz_residual"][:, :3],
trajectory["action"]["future/axis_angle_residual"][:, :3],
invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def utokyo_xarm_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., -7:]
return trajectory
def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :4],
tf.zeros_like(trajectory["observation"]["state"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["end_effector_pose"][:, :4],
tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def dlr_sara_pour_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
return trajectory
def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
return trajectory
def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, 7:8],
),
axis=-1,
)
return trajectory
def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["language_instruction"]), ""
# ) # delete uninformative language instruction
return trajectory
def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]
# dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"],
invert_gripper_actions(trajectory["observation"]["gripper_state"]),
),
axis=-1,
)
return trajectory
def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
trajectory["action"][:, -4:],
),
axis=-1,
)
return trajectory
def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["position"],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
trajectory["observation"]["yaw"],
),
axis=-1,
)
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["eef_pose"],
trajectory["observation"]["state_gripper_pose"][..., None],
),
axis=-1,
)
return trajectory
def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def roboset_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
# gripper action is in -1...1 --> clip to 0...1, flip
gripper_action = trajectory["action"][:, -1:]
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :7],
gripper_action,
),
axis=-1,
)
return trajectory
def rh20t_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["tcp_base"],
tf.cast(trajectory["action"]["gripper"][:, None], tf.float32),
),
axis=-1,
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["tcp_base"],
trajectory["observation"]["gripper_width"][..., None],
),
axis=-1,
)
return trajectory
def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
return trajectory
def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# gripper action is in -1 (open)...1 (close) --> clip to 0...1, flip --> +1 = open, 0 = close
gripper_action = trajectory["action"][:, -1:]
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
gripper_action,
],
axis=1,
)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:] # 2D gripper state
return trajectory
# === Registry ===
OXE_STANDARDIZATION_TRANSFORMS = {
"bridge_oxe": bridge_oxe_dataset_transform,
"bridge_orig": bridge_orig_dataset_transform,
"bridge_dataset": bridge_orig_dataset_transform,
"ppgm": ppgm_dataset_transform,
"ppgm_static": ppgm_dataset_transform,
"ppgm_wrist": ppgm_dataset_transform,
"fractal20220817_data": rt1_dataset_transform,
"kuka": kuka_dataset_transform,
"taco_play": taco_play_dataset_transform,
"jaco_play": jaco_play_dataset_transform,
"berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
"roboturk": roboturk_dataset_transform,
"nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
"viola": viola_dataset_transform,
"berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
"toto": toto_dataset_transform,
"language_table": language_table_dataset_transform,
"columbia_cairlab_pusht_real": pusht_dataset_transform,
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
"nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
"stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
"austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
"nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
"maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
"furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
"cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
"ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
"austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
"austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
"bc_z": bc_z_dataset_transform,
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform,
"utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
"robo_net": robo_net_dataset_transform,
"berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
"berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
"kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
"stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
"tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
"dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform,
"dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
"dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
"asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
"stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
"imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
"uiuc_d3field": uiuc_d3field_dataset_transform,
"utaustin_mutex": utaustin_mutex_dataset_transform,
"berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
"cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
"cmu_play_fusion": playfusion_dataset_transform,
"cmu_stretch": cmu_stretch_dataset_transform,
"berkeley_gnm_recon": gnm_dataset_transform,
"berkeley_gnm_cory_hall": gnm_dataset_transform,
"berkeley_gnm_sac_son": gnm_dataset_transform,
"droid": droid_baseact_transform,
"fmb_dataset": fmb_dataset_transform,
"dobbe": dobbe_dataset_transform,
"roboset": roboset_dataset_transform,
"rh20t_rlds": rh20t_dataset_transform,
### T-DROID datasets
"tdroid_carrot_in_bowl": tdroid_dataset_transform,
"tdroid_pour_corn_in_pot": tdroid_dataset_transform,
"tdroid_flip_pot_upright": tdroid_dataset_transform,
"tdroid_move_object_onto_plate": tdroid_dataset_transform,
"tdroid_knock_object_over": tdroid_dataset_transform,
"tdroid_cover_object_with_towel": tdroid_dataset_transform,
### DROID Finetuning datasets
"droid_wipe": droid_finetuning_transform,
### LIBERO datasets (modified versions)
"libero_spatial_no_noops": libero_dataset_transform,
"libero_object_no_noops": libero_dataset_transform,
"libero_goal_no_noops": libero_dataset_transform,
"libero_10_no_noops": libero_dataset_transform,
}