From 3380665e3ead37c1f77de9e77231dc12021ac05d Mon Sep 17 00:00:00 2001 From: youliangtan Date: Thu, 11 Jul 2024 18:07:35 -0700 Subject: [PATCH] apply oxe configs and transform for dataset standardization Signed-off-by: youliangtan --- .../push_dataset_to_hub/oxe/configs.py | 651 +++++++++++++ .../push_dataset_to_hub/oxe/data_utils.py | 91 ++ .../push_dataset_to_hub/oxe/droid_utils.py | 183 ++++ .../push_dataset_to_hub/oxe/transforms.py | 912 ++++++++++++++++++ .../push_dataset_to_hub/oxe_rlds_format.py | 161 +++- lerobot/scripts/push_dataset_to_hub.py | 26 +- 6 files changed, 1978 insertions(+), 46 deletions(-) create mode 100644 lerobot/common/datasets/push_dataset_to_hub/oxe/configs.py create mode 100644 lerobot/common/datasets/push_dataset_to_hub/oxe/data_utils.py create mode 100644 lerobot/common/datasets/push_dataset_to_hub/oxe/droid_utils.py create mode 100644 lerobot/common/datasets/push_dataset_to_hub/oxe/transforms.py diff --git a/lerobot/common/datasets/push_dataset_to_hub/oxe/configs.py b/lerobot/common/datasets/push_dataset_to_hub/oxe/configs.py new file mode 100644 index 00000000..85f6ee2d --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/oxe/configs.py @@ -0,0 +1,651 @@ +""" +NOTE(YL): Adapted from: + OpenVLA: https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/configs.py + Octo: https://github.com/octo-models/octo/blob/main/octo/data/oxe/oxe_dataset_configs.py + +TODO: implement the following: + - Populate all `fps` for each dataset + - Upload the dataset config to the Readme of each dataset on huggingface hub, for verbosity + +configs.py + +Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment. + +Configuration adopts the following structure: + image_obs_keys: + primary: primary external RGB + secondary: secondary external RGB + wrist: wrist RGB + + depth_obs_keys: + primary: primary external depth + secondary: secondary external depth + wrist: wrist depth + + # Always 8-dim =>> changes based on `StateEncoding` + state_obs_keys: + StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + (1) + Gripper Open/Close (1) + StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1) + StateEncoding.JOINT: Joint Angles (7, if fewer) + Gripper Open/Close (1) + + state_encoding: Type of `StateEncoding` + action_encoding: Type of action encoding (e.g., EEF Position vs. 
Joint Position) +""" + +from enum import IntEnum + + +# Defines Proprioceptive State Encoding Schemes +class StateEncoding(IntEnum): + # fmt: off + NONE = -1 # No Proprioceptive State + POS_EULER = 1 # EEF XYZ (3) + Roll-Pitch-Yaw (3) + (1) + Gripper Open/Close (1) + POS_QUAT = 2 # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1) + JOINT = 3 # Joint Angles (7, if fewer) + Gripper Open/Close (1) + JOINT_BIMANUAL = 4 # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ]) + # fmt: on + + +# Defines Action Encoding Schemes +class ActionEncoding(IntEnum): + # fmt: off + EEF_POS = 1 # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1) + JOINT_POS = 2 # Joint Delta Position (7) + Gripper Open/Close (1) + JOINT_POS_BIMANUAL = 3 # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ]) + EEF_R6 = 4 # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1) + # fmt: on + + +# === Individual Dataset Configs === +OXE_DATASET_CONFIGS = { + "fractal20220817_data": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["base_pose_tool_reached", "gripper_closed"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "fps": 3, # TODO: placeholder (youliang) + }, + "kuka": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [ + "clip_function_input/base_pose_tool_reached", + "gripper_closed", + ], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "bridge_oxe": { # Version of Bridge V2 in Open X-Embodiment mixture + "image_obs_keys": {"primary": "image", "secondary": "image_1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "bridge_orig": { # Original version of Bridge V2 from project website + "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "fps": 5, + }, + "bridge_dataset": { # Original version of Bridge V2 from project website + "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "fps": 5, + }, + "taco_play": { + "image_obs_keys": { + "primary": "rgb_static", + "secondary": None, + "wrist": "rgb_gripper", + }, + "depth_obs_keys": { + "primary": "depth_static", + "secondary": None, + "wrist": "depth_gripper", + }, + "state_obs_keys": ["state_eef", None, "state_gripper"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "jaco_play": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "image_wrist", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state_eef", None, "state_gripper"], + "state_encoding": StateEncoding.POS_EULER, + 
"action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_cable_routing": { + "image_obs_keys": { + "primary": "image", + "secondary": "top_image", + "wrist": "wrist45_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["robot_state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "roboturk": { + "image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "nyu_door_opening_surprising_effectiveness": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "viola": { + "image_obs_keys": { + "primary": "agentview_rgb", + "secondary": None, + "wrist": "eye_in_hand_rgb", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_states", "gripper_states"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_autolab_ur5": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "hand_image", + }, + "depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "toto": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "language_table": { + "image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["effector_translation", None, None, None, None, None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "columbia_cairlab_pusht_real": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["robot_state", None, None, None, None, None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["ee_position", "ee_orientation", None], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "nyu_rot_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_hydra_dataset_converted_externally_to_rlds": { + 
"image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "austin_buds_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + "fps": 20, + }, + "nyu_franka_play_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": "image_additional_view", + "wrist": None, + }, + "depth_obs_keys": { + "primary": "depth", + "secondary": "depth_additional_view", + "wrist": None, + }, + "state_obs_keys": ["eef_state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "fps": 3, + }, + "maniskill_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": { + "primary": "depth", + "secondary": None, + "wrist": "wrist_depth", + }, + "state_obs_keys": ["tcp_pose", "gripper_state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "furniture_bench_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_franka_exploration_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "highres_image", + "secondary": None, + "wrist": None, + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "ucsd_kitchen_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "austin_sailor_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "austin_sirius_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": 
["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "bc_z": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [ + "present/xyz", + "present/axis_angle", + None, + "present/sensed_close", + ], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": "image2", + "wrist": "hand_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["end_effector_pose", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_xarm_bimanual_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["pose_r", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "robo_net": { + "image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_mvp_converted_externally_to_rlds": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["pose", "gripper"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.JOINT_POS, + }, + "berkeley_rpt_converted_externally_to_rlds": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_pos", "gripper"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.JOINT_POS, + }, + "kaist_nonprehensile_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_mask_vit_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": 
StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tokyo_u_lsmo_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dlr_sara_pour_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dlr_sara_grid_clamp_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dlr_edan_shared_control_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "asu_table_top_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_robocook_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, + "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "imperialcollege_sawyer_wrist_cam": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, "state"], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", "gripper_state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "uiuc_d3field": { + "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, + "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utaustin_mutex": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": 
ActionEncoding.EEF_POS, + }, + "berkeley_fanuc_manipulation": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", None, "gripper_state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_playing_with_food": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "finger_vision_1", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_play_fusion": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_stretch": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "fps": 10, + }, + "berkeley_gnm_recon": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_gnm_cory_hall": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_gnm_sac_son": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "droid": { + "image_obs_keys": { + "primary": "exterior_image_1_left", + "secondary": "exterior_image_2_left", + "wrist": "wrist_image_left", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "fmb_dataset": { + "image_obs_keys": { + "primary": "image_side_1", + "secondary": "image_side_2", + "wrist": "image_wrist_1", + }, + "depth_obs_keys": { + "primary": "image_side_1_depth", + "secondary": "image_side_2_depth", + "wrist": "image_wrist_1_depth", + }, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dobbe": { + "image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "roboset": { + "image_obs_keys": { + "primary": "image_left", + "secondary": "image_right", + "wrist": "image_wrist", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": 
["proprio"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.JOINT_POS, + }, + "rh20t": { + "image_obs_keys": { + "primary": "image_front", + "secondary": "image_side_right", + "wrist": "image_wrist", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### T-DROID datasets + "tdroid_carrot_in_bowl": { # "put carrot in bowl" task, 50 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_pour_corn_in_pot": { # "pour corn from red bowl into steel pot" task, 50 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_move_object_onto_plate": { # "move onto plate" task, 150 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_knock_object_over": { # "knock over" task, 70 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_cover_object_with_towel": { # "cover with towel" task, 45 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### DROID Finetuning datasets + "droid_wipe": { + "image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, +} \ No newline at end of file diff --git a/lerobot/common/datasets/push_dataset_to_hub/oxe/data_utils.py b/lerobot/common/datasets/push_dataset_to_hub/oxe/data_utils.py new file mode 100644 index 00000000..f95bdff8 --- /dev/null +++ 
b/lerobot/common/datasets/push_dataset_to_hub/oxe/data_utils.py @@ -0,0 +1,91 @@ +""" +NOTE(YL): Adapted from: + Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py + +data_utils.py + +Additional utils for data processing. +""" + +from typing import Any, Dict, List + +import tensorflow as tf + + +def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor: + """ + Converts gripper actions from continuous to binary values (0 and 1). + + We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it + transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate + values based on the state that is reached _after_ those intermediate values. + + In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that + chunk of intermediate values as the last action in the trajectory. + + The `scan_fn` implements the following logic: + new_actions = np.empty_like(actions) + carry = actions[-1] + for i in reversed(range(actions.shape[0])): + if in_between_mask[i]: + carry = carry + else: + carry = float(open_mask[i]) + new_actions[i] = carry + """ + open_mask, closed_mask = actions > 0.95, actions < 0.05 + in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask)) + is_open_float = tf.cast(open_mask, tf.float32) + + def scan_fn(carry, i): + return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i]) + + return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True) + + +def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor: + return 1 - actions + + +def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor: + """ + Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open). + + Assumes that the first relative gripper is not redundant (i.e. close when already closed)! 
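+
+    Illustrative example (values chosen for this note, not taken from any dataset):
+        relative actions     [ 0,  0, +1,  0, -1,  0]   (+1 closes at t=2, -1 opens at t=4)
+        thresholded          [ 0,  0, -1,  0, +1,  0]
+        carried forward      [+1, +1, -1, -1, +1, +1]   (start is inferred to be open)
+        rescaled to [0, 1]   [ 1,  1,  0,  0,  1,  1]   (1 = open, 0 = closed)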
+ """ + # Note =>> -1 for closing, 1 for opening, 0 for no change + opening_mask, closing_mask = actions < -0.1, actions > 0.1 + thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0)) + + def scan_fn(carry, i): + return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i]) + + # If no relative grasp, assumes open for whole trajectory + start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)] + start = tf.cond(start == 0, lambda: 1, lambda: start) + + # Note =>> -1 for closed, 1 for open + new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start) + new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5 + + return new_actions + + +# === Bridge-V2 =>> Dataset-Specific Transform === +def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]: + """Relabels actions to use reached proprioceptive state; discards last timestep (no-action).""" + movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6] + traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj) + traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1) + + return traj_truncated + + +# === RLDS Dataset Initialization Utilities === +def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None: + print("\n######################################################################################") + print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #") + for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights): + pad = 80 - len(dataset_kwargs["name"]) + print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #") + print("######################################################################################\n") diff --git a/lerobot/common/datasets/push_dataset_to_hub/oxe/droid_utils.py b/lerobot/common/datasets/push_dataset_to_hub/oxe/droid_utils.py new file mode 100644 index 00000000..a2e20a63 --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/oxe/droid_utils.py @@ -0,0 +1,183 @@ +""" +NOTE(YL): Adapted from: + OpenVLA: https://github.com/openvla/openvla + +Episode transforms for DROID dataset. +""" + +from typing import Any, Dict + +import tensorflow as tf +import tensorflow_graphics.geometry.transformation as tfg + + +def rmat_to_euler(rot_mat): + return tfg.euler.from_rotation_matrix(rot_mat) + + +def euler_to_rmat(euler): + return tfg.rotation_matrix_3d.from_euler(euler) + + +def invert_rmat(rot_mat): + return tfg.rotation_matrix_3d.inverse(rot_mat) + + +def rotmat_to_rot6d(mat): + """ + Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix). + Args: + mat: rotation matrix + + Returns: 6d vector (first two rows of rotation matrix) + + """ + r6 = mat[..., :2, :] + r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :] + r6_flat = tf.concat([r6_0, r6_1], axis=-1) + return r6_flat + + +def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame): + """ + Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame. 
+ Args: + velocity: 6d velocity action (3 x translation, 3 x rotation) + wrist_in_robot_frame: 6d pose of the end-effector in robot base frame + + Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6) + + """ + R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6]) + R_frame_inv = invert_rmat(R_frame) + + # world to wrist: dT_pi = R^-1 dT_rbt + vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0] + + # world to wrist: dR_pi = R^-1 dR_rbt R + dR = euler_to_rmat(velocity[:, 3:6]) + dR = R_frame_inv @ (dR @ R_frame) + dR_r6 = rotmat_to_rot6d(dR) + return tf.concat([vel_t, dR_r6], axis=-1) + + +def rand_swap_exterior_images(img1, img2): + """ + Randomly swaps the two exterior images (for training with single exterior input). + """ + return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1)) + + +def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + DROID dataset transformation for actions expressed in *base* frame of the robot. + """ + dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] + dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] + + trajectory["action"] = tf.concat( + ( + dt, + dR, + 1 - trajectory["action_dict"]["gripper_position"], + ), + axis=-1, + ) + trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( + rand_swap_exterior_images( + trajectory["observation"]["exterior_image_1_left"], + trajectory["observation"]["exterior_image_2_left"], + ) + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["cartesian_position"], + trajectory["observation"]["gripper_position"], + ), + axis=-1, + ) + return trajectory + + +def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + DROID dataset transformation for actions expressed in *wrist* frame of the robot. + """ + wrist_act = velocity_act_to_wrist_frame( + trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"] + ) + trajectory["action"] = tf.concat( + ( + wrist_act, + trajectory["action_dict"]["gripper_position"], + ), + axis=-1, + ) + trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( + rand_swap_exterior_images( + trajectory["observation"]["exterior_image_1_left"], + trajectory["observation"]["exterior_image_2_left"], + ) + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["cartesian_position"], + trajectory["observation"]["gripper_position"], + ), + axis=-1, + ) + return trajectory + + +def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + DROID dataset transformation for actions expressed in *base* frame of the robot. + """ + dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] + dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] + trajectory["action"] = tf.concat( + ( + dt, + dR, + 1 - trajectory["action_dict"]["gripper_position"], + ), + axis=-1, + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["cartesian_position"], + trajectory["observation"]["gripper_position"], + ), + axis=-1, + ) + return trajectory + + +def zero_action_filter(traj: Dict) -> bool: + """ + Filters transitions whose actions are all-0 (only relative actions, no gripper action). + Note: this filter is applied *after* action normalization, so need to compare to "normalized 0". 
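+
+    Concretely, the normalization assumed here is
+        x_norm = 2 * (x - q01) / (q99 - q01) - 1
+    with the 1st/99th-percentile statistics hard-coded below, so a raw zero action maps to
+    2 * (0 - q01) / (q99 - q01) - 1. The function returns True if any timestep's first six
+    action dimensions differ from that value by more than 1e-5, i.e. if the trajectory
+    contains at least one non-zero action.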
+ """ + DROID_Q01 = tf.convert_to_tensor( + [ + -0.7776297926902771, + -0.5803514122962952, + -0.5795090794563293, + -0.6464047729969025, + -0.7041108310222626, + -0.8895104378461838, + ] + ) + DROID_Q99 = tf.convert_to_tensor( + [ + 0.7597932070493698, + 0.5726242214441299, + 0.7351000607013702, + 0.6705610305070877, + 0.6464948207139969, + 0.8897542208433151, + ] + ) + DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1 + + return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5) \ No newline at end of file diff --git a/lerobot/common/datasets/push_dataset_to_hub/oxe/transforms.py b/lerobot/common/datasets/push_dataset_to_hub/oxe/transforms.py new file mode 100644 index 00000000..becac2eb --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/oxe/transforms.py @@ -0,0 +1,912 @@ +""" +NOTE(YL): Adapted from: + OpenVLA: https://github.com/openvla/openvla + Octo: https://github.com/octo-models/octo + +transforms.py + +Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment. + +Transforms adopt the following structure: + Input: Dictionary of *batched* features (i.e., has leading time dimension) + Output: Dictionary `step` =>> { + "observation": { + + State (in chosen state representation) + }, + "action": Action (in chosen action representation), + "language_instruction": str + } +""" + +from typing import Any, Dict + +import tensorflow as tf + +from lerobot.common.datasets.push_dataset_to_hub.oxe.data_utils import ( + binarize_gripper_actions, + invert_gripper_actions, + rel2abs_gripper_actions, + relabel_bridge_actions, +) + +def droid_baseact_transform_fn(): + from lerobot.common.datasets.push_dataset_to_hub.oxe.droid_utils import droid_baseact_transform + return droid_baseact_transform + + +def droid_finetuning_transform_fn(): + from lerobot.common.datasets.push_dataset_to_hub.oxe.droid_utils import droid_finetuning_transform + return droid_finetuning_transform + + +def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + Applies to version of Bridge V2 in Open X-Embodiment mixture. + + Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! + """ + for key in trajectory.keys(): + if key == "traj_metadata": + continue + elif key in ["observation", "action"]: + for key2 in trajectory[key]: + trajectory[key][key2] = trajectory[key][key2][1:] + else: + trajectory[key] = trajectory[key][1:] + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + trajectory = relabel_bridge_actions(trajectory) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + Applies to original version of Bridge V2 from the official project website. + + Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! 
+ """ + for key in trajectory.keys(): + if key == "traj_metadata": + continue + elif key == "observation": + for key2 in trajectory[key]: + trajectory[key][key2] = trajectory[key][key2][1:] + else: + trajectory[key] = trajectory[key][1:] + + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory = relabel_bridge_actions(trajectory) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] + return trajectory + + +def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + # decode compressed state + eef_value = tf.io.decode_compressed( + trajectory["observation"]["clip_function_input/base_pose_tool_reached"], + compression_type="ZLIB", + ) + eef_value = tf.io.decode_raw(eef_value, tf.float32) + trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7)) + gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB") + gripper_value = tf.io.decode_raw(gripper_value, tf.float32) + trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1)) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6] + trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8] + trajectory["action"] = trajectory["action"]["rel_actions_world"] + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + tf.clip_by_value(trajectory["action"][:, -1:], 0, 1), + ), + axis=-1, + ) + + 
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6] + trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:] + + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + tf.zeros_like(trajectory["action"]["world_vector"]), + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.zeros_like(trajectory["action"]["world_vector"][:, :1]), + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert absolute gripper action, +1 = open, 0 = close + gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action, + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, None] + gripper_action = tf.clip_by_value(gripper_action, 0, 1) + gripper_action = invert_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action, + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # 
tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14] + trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth") + + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # default to "open" gripper + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"]), + tf.ones_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + + # decode language instruction + instruction_bytes = trajectory["observation"]["instruction"] + instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8") + # Remove trailing padding --> convert RaggedTensor to regular Tensor. 
+ trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0] + return trajectory + + +def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + trajectory["action"]["gripper_closedness_action"][:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tf.zeros_like(trajectory["action"][:, :3]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:] + trajectory["action"] = trajectory["action"][..., :7] + return trajectory + + +def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(trajectory["action"][:, -1:]), + ), + axis=-1, + ) + + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :3], + trajectory["observation"]["state"][:, 7:10], + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2] + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32) + trajectory["observation"]["depth_additional_view"] = tf.cast( + trajectory["observation"]["depth_additional_view"][..., 0], tf.float32 + ) + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:] + + # clip gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, -8:-2], + tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8] + return trajectory + + +def 
furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["observation"]["state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :7], + trajectory["observation"]["state"][:, -1:], + ), + axis=-1, + ) + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + return trajectory + + +def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tf.zeros_like(trajectory["action"][:, :3]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["future/xyz_residual"][:, :3], + trajectory["action"]["future/axis_angle_residual"][:, :3], + invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)), + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = 
trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def utokyo_xarm_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = trajectory["action"][..., -7:] + return trajectory + + +def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :4], + tf.zeros_like(trajectory["observation"]["state"][:, :2]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :4], + tf.zeros_like(trajectory["action"][:, :2]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["end_effector_pose"][:, :4], + tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :4], + tf.zeros_like(trajectory["action"][:, :2]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def dlr_sara_pour_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6] + return trajectory + + +def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(trajectory["action"][:, -1:]), + ), + axis=-1, + ) + return trajectory + + +def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def 
imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + trajectory["action"][:, 7:8], + ), + axis=-1, + ) + return trajectory + + +def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7] + + # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"], + invert_gripper_actions(trajectory["observation"]["gripper_state"]), + ), + axis=-1, + ) + return trajectory + + +def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + trajectory["action"][:, -4:], + ), + axis=-1, + ) + return trajectory + + +def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :3], + tf.zeros_like(trajectory["observation"]["state"][:, :3]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = tf.concat( + ( + trajectory["observation"]["position"], + tf.zeros_like(trajectory["observation"]["state"][:, :3]), + trajectory["observation"]["yaw"], + ), + axis=-1, + ) + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"]), + 
tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["eef_pose"], + trajectory["observation"]["state_gripper_pose"][..., None], + ), + axis=-1, + ) + return trajectory + + +def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + trajectory["observation"]["proprio"] = trajectory["observation"]["state"] + return trajectory + + +def roboset_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + trajectory["observation"]["proprio"] = trajectory["observation"]["state"] + + # gripper action is in -1...1 --> clip to 0...1, flip + gripper_action = trajectory["action"][:, -1:] + gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) + + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :7], + gripper_action, + ), + axis=-1, + ) + return trajectory + + +def rh20t_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["tcp_base"], + tf.cast(trajectory["action"]["gripper"][:, None], tf.float32), + ), + axis=-1, + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["tcp_base"], + trajectory["observation"]["gripper_width"][..., None], + ), + axis=-1, + ) + return trajectory + + +def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] + return trajectory + + +# === Registry === +OXE_STANDARDIZATION_TRANSFORMS = { + "bridge_oxe": bridge_oxe_dataset_transform, + "bridge_orig": bridge_orig_dataset_transform, + "bridge_dataset": bridge_orig_dataset_transform, + "ppgm": ppgm_dataset_transform, + "ppgm_static": ppgm_dataset_transform, + "ppgm_wrist": ppgm_dataset_transform, + "fractal20220817_data": rt1_dataset_transform, + "kuka": kuka_dataset_transform, + "taco_play": taco_play_dataset_transform, + "jaco_play": jaco_play_dataset_transform, + "berkeley_cable_routing": berkeley_cable_routing_dataset_transform, + "roboturk": roboturk_dataset_transform, + "nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform, + "viola": viola_dataset_transform, + "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform, + "toto": toto_dataset_transform, + "language_table": language_table_dataset_transform, + "columbia_cairlab_pusht_real": pusht_dataset_transform, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform, + "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform, + "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform, + "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform, + "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform, + "maniskill_dataset_converted_externally_to_rlds": 
maniskill_dataset_transform, + "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform, + "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform, + "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform, + "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform, + "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform, + "bc_z": bc_z_dataset_transform, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform, + "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform, + "robo_net": robo_net_dataset_transform, + "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform, + "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform, + "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform, + "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform, + "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform, + "dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform, + "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform, + "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform, + "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform, + "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform, + "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform, + "uiuc_d3field": uiuc_d3field_dataset_transform, + "utaustin_mutex": utaustin_mutex_dataset_transform, + "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform, + "cmu_playing_with_food": cmu_playing_with_food_dataset_transform, + "cmu_play_fusion": playfusion_dataset_transform, + "cmu_stretch": cmu_stretch_dataset_transform, + "berkeley_gnm_recon": gnm_dataset_transform, + "berkeley_gnm_cory_hall": gnm_dataset_transform, + "berkeley_gnm_sac_son": gnm_dataset_transform, + "droid": droid_baseact_transform_fn, + "fmb_dataset": fmb_dataset_transform, + "dobbe": dobbe_dataset_transform, + "roboset": roboset_dataset_transform, + "rh20t": rh20t_dataset_transform, + ### T-DROID datasets + "tdroid_carrot_in_bowl": tdroid_dataset_transform, + "tdroid_pour_corn_in_pot": tdroid_dataset_transform, + "tdroid_flip_pot_upright": tdroid_dataset_transform, + "tdroid_move_object_onto_plate": tdroid_dataset_transform, + "tdroid_knock_object_over": tdroid_dataset_transform, + "tdroid_cover_object_with_towel": tdroid_dataset_transform, + ### DROID Finetuning datasets + "droid_wipe": droid_finetuning_transform_fn, +} \ No newline at end of file diff --git a/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py b/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py index 0274062f..2d20e8f5 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py +++ 
b/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py
@@ -7,11 +7,10 @@
 Example:
     python lerobot/scripts/push_dataset_to_hub.py \
     --raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
     --repo-id youliangtan/sampled_bridge_data_v2 \
-    --raw-format oxe_rlds \
-    --episodes 3 4 5 8 9 \
-    --fps 5
+    --raw-format oxe_rlds.bridge_orig \
+    --episodes 3 4 5 8 9
 
-Exact dataset fps is specified in:
+Exact dataset fps is defined in oxe/configs.py, originally obtained from:
     https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
 """
 
@@ -19,12 +18,15 @@ import shutil
 from pathlib import Path
 
 import numpy as np
+import tensorflow as tf
 import tensorflow_datasets as tfds
 import torch
 import tqdm
 from datasets import Dataset, Features, Image, Sequence, Value
 from PIL import Image as PILImage
 
+from lerobot.common.datasets.push_dataset_to_hub.oxe.configs import OXE_DATASET_CONFIGS
+from lerobot.common.datasets.push_dataset_to_hub.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
     calculate_episode_data_index,
@@ -43,7 +45,51 @@ def tf_to_torch(data):
     return torch.from_numpy(data.numpy())
 
 
-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def tf_img_convert(img):
+    if img.dtype == tf.string:
+        img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
+    elif img.dtype != tf.uint8:
+        raise ValueError(f"Unsupported image dtype: {img.dtype}")
+    return img.numpy()
+
+
+def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
+    """
+    In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
+    entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
+    trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
+
+    NOTE: adapted from the DLimp library https://github.com/kvablack/dlimp/
+    """
+    steps = traj.pop("steps")
+
+    traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]
+
+    # broadcast metadata to the length of the trajectory
+    metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)
+
+    # put steps back in
+    assert "traj_metadata" not in steps
+    traj = {**steps, "traj_metadata": metadata}
+
+    assert "_len" not in traj
+    assert "_traj_index" not in traj
+    assert "_frame_index" not in traj
+    traj["_len"] = tf.repeat(traj_len, traj_len)
+    traj["_traj_index"] = tf.repeat(i, traj_len)
+    traj["_frame_index"] = tf.range(traj_len)
+
+    return traj
+
+
+def load_from_raw(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int,
+    video: bool,
+    episodes: list[int],
+    oxe_dataset_name: str | None = None,
+):
     """
     Args:
         raw_dir (Path): _description_
@@ -53,22 +99,32 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         episodes (list[int] | None, optional): _description_. Defaults to None.
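+        oxe_dataset_name (str | None, optional): official OXE dataset name (e.g. "bridge_orig"), used to look up
+            the dataset config in oxe/configs.py and the standardization transform in oxe/transforms.py.
+            Defaults to None.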
""" ds_builder = tfds.builder_from_directory(str(raw_dir)) - dataset = ds_builder.as_dataset(split="all") + dataset = ds_builder.as_dataset( + split="all", + decoders={"steps": tfds.decode.SkipDecoding()}, + ) dataset_info = ds_builder.info print("dataset_info: ", dataset_info) - image_keys = get_cameras_keys(dataset_info.features["steps"]["observation"].keys()) + ds_length = len(dataset) + dataset = dataset.take(ds_length) - # check if there's a 'tfds.features.Text' in step, only take 1 lang instruction - lang_key = [ - key for key, value in dataset_info.features["steps"].items() if isinstance(value, tfds.features.Text) - ] - lang_key = None if len(lang_key) == 0 else lang_key[0] + # "flatten" the dataset as such we can apply trajectory level map() easily + # each [obs][key] has a shape of (frame_size, ...) + dataset = dataset.enumerate().map(_broadcast_metadata_rlds) + + # we will apply the standardization transform if the dataset_name is provided + if oxe_dataset_name is not None: + print(" - applying standardization transform for dataset: ", oxe_dataset_name) + assert oxe_dataset_name in OXE_STANDARDIZATION_TRANSFORMS + transform_fn = OXE_STANDARDIZATION_TRANSFORMS[oxe_dataset_name] + dataset = dataset.map(transform_fn) + + image_keys = get_cameras_keys(dataset_info.features["steps"]["observation"].keys()) + lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None print(" - image_keys: ", image_keys) print(" - lang_key: ", lang_key) - ds_length = len(dataset) - dataset = dataset.take(ds_length) it = iter(dataset) ep_dicts = [] @@ -83,7 +139,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod for ep_idx in tqdm.tqdm(range(ds_length)): episode = next(it) - # if we user specified episodes, skip the ones not in the list + # if user specified episodes, skip the ones not in the list if episodes is not None: if len(episodes) == 0: break @@ -94,38 +150,49 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod else: continue # skip - steps = episode["steps"] - num_frames = len(steps) + num_frames = episode["action"].shape[0] + + ########################################################### + # Handle the episodic data # last step of demonstration is considered done done = torch.zeros(num_frames, dtype=torch.bool) done[-1] = True - - states = [] - actions = [] # TODO(YL): some actions can be a featuredict - rewards = torch.zeros(num_frames, dtype=torch.float32) ep_dict = {} - langs = [] + langs = [] # TODO: might be located in "observation" image_array_dict = {key: [] for key in image_keys} - ########################################################### - # loop through all steps in the episode - for j, step in enumerate(steps): - states.append(tf_to_torch(step["observation"]["state"])) - actions.append(tf_to_torch(step["action"])) - rewards[j] = torch.tensor(step["reward"].numpy(), dtype=torch.float32) + # We will create the state observation tensor by stacking the state + # obs keys defined in the oxe/configs.py + if oxe_dataset_name is not None: + state_obs_keys = OXE_DATASET_CONFIGS[oxe_dataset_name]["state_obs_keys"] + # stack the state observations, if is None, pad with zeros + states = [] + for key in state_obs_keys: + if key in episode["observation"]: + states.append(tf_to_torch(episode["observation"][key])) + else: + states.append(torch.zeros(num_frames, 1)) # pad with zeros + states = torch.cat(states, dim=1) + assert states.shape == (num_frames, 8) + else: + states = 
 
-        if lang_key is not None:
-            langs.append(str(step[lang_key]))
+        actions = tf_to_torch(episode["action"])
+        rewards = tf_to_torch(episode["reward"]).float()
 
-        for im_key in image_keys:
-            if im_key not in step["observation"]:
-                continue
+        # If lang_key is present, convert the entire tensor at once
+        if lang_key is not None:
+            langs = [str(x) for x in episode[lang_key]]
 
-            img = step["observation"][im_key]
-            img = np.array(img)
-            image_array_dict[im_key].append(img)
+        for im_key in image_keys:
+            imgs = episode["observation"][im_key]
+            image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
+
+        # sanity check: every per-frame tensor must cover the whole episode
+        for item in [states, actions, rewards, done]:
+            assert len(item) == num_frames
 
         ###########################################################
@@ -157,12 +224,12 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         if lang_key is not None:
             ep_dict["language_instruction"] = langs
 
-        ep_dict["observation.state"] = torch.stack(states)  # TODO better way
-        ep_dict["action"] = torch.stack(actions)
+        ep_dict["observation.state"] = states
+        ep_dict["action"] = actions
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
-        ep_dict["reward"] = rewards
+        ep_dict["next.reward"] = rewards
         ep_dict["next.done"] = done
 
         ep_dicts.append(ep_dict)
@@ -204,7 +271,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     features["episode_index"] = Value(dtype="int64", id=None)
     features["frame_index"] = Value(dtype="int64", id=None)
     features["timestamp"] = Value(dtype="float32", id=None)
-    features["reward"] = Value(dtype="float32", id=None)
+    features["next.reward"] = Value(dtype="float32", id=None)
     features["next.done"] = Value(dtype="bool", id=None)
     features["index"] = Value(dtype="int64", id=None)
 
@@ -219,12 +286,22 @@ def from_raw_to_lerobot_format(
     fps: int | None = None,
     video: bool = True,
     episodes: list[int] | None = None,
+    oxe_dataset_name: str | None = None,
 ):
     """This is a test impl for rlds conversion"""
     if fps is None:
-        fps = 5
+        if oxe_dataset_name is not None:
+            if "fps" not in OXE_DATASET_CONFIGS[oxe_dataset_name]:
+                raise ValueError(
+                    "fps for this dataset is not specified in oxe/configs.py yet, "
+                    "which means it has not been tested yet"
+                )
+            fps = OXE_DATASET_CONFIGS[oxe_dataset_name]["fps"]
+        else:
+            print(" - WARNING: fps is not provided, using default value of 5 fps")
+            fps = 5
 
-    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, oxe_dataset_name)
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py
index e473aa0c..a65cc6b0 100644
--- a/lerobot/scripts/push_dataset_to_hub.py
+++ b/lerobot/scripts/push_dataset_to_hub.py
@@ -65,7 +65,7 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
         from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
     elif raw_format == "aloha_hdf5":
         from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
-    elif raw_format == "oxe_rlds":
+    elif "oxe_rlds" in raw_format:
         from lerobot.common.datasets.push_dataset_to_hub.oxe_rlds_format import from_raw_to_lerobot_format
     elif raw_format == "dora_parquet":
         from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
@@ -192,9 +192,27 @@ def push_dataset_to_hub(
     # convert dataset from original raw format to LeRobot format
     from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
-    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
-        raw_dir, videos_dir, fps, video, episodes
-    )
+
+    if "oxe_rlds" in raw_format:
+        # The user can provide the official OXE dataset name to convert it to LeRobot format.
+        # The raw_format string follows the pattern:
+        #    oxe_rlds                (default)
+        #    oxe_rlds.bridge_orig    (with bridge_orig as oxe_dataset_name)
+        split_raw_format = raw_format.split(".")
+        assert len(split_raw_format) <= 2, f"Invalid raw_format: {raw_format}"
+        if len(split_raw_format) == 2:
+            oxe_dataset_name = split_raw_format[1]
+            print(f"Converting dataset [{oxe_dataset_name}] from 'oxe_rlds' to LeRobot format.")
+        else:
+            oxe_dataset_name = None
+
+        hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
+            raw_dir, videos_dir, fps, video, episodes, oxe_dataset_name=oxe_dataset_name
+        )
+    else:
+        hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
+            raw_dir, videos_dir, fps, video, episodes
+        )
 
     lerobot_dataset = LeRobotDataset.from_preloaded(
         repo_id=repo_id,