apply oxe configs and transform for dataset standardization
Signed-off-by: youliangtan <tan_you_liang@hotmail.com>
parent 85fec65b3e
commit 3380665e3e

lerobot/common/datasets/push_dataset_to_hub/oxe/configs.py
@@ -0,0 +1,651 @@
"""
|
||||
NOTE(YL): Adapted from:
|
||||
OpenVLA: https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/configs.py
|
||||
Octo: https://github.com/octo-models/octo/blob/main/octo/data/oxe/oxe_dataset_configs.py
|
||||
|
||||
TODO: implement the following:
|
||||
- Populate all `fps` for each dataset
|
||||
- Upload the dataset config to the Readme of each dataset on huggingface hub, for verbosity
|
||||
|
||||
configs.py
|
||||
|
||||
Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment.
|
||||
|
||||
Configuration adopts the following structure:
|
||||
image_obs_keys:
|
||||
primary: primary external RGB
|
||||
secondary: secondary external RGB
|
||||
wrist: wrist RGB
|
||||
|
||||
depth_obs_keys:
|
||||
primary: primary external depth
|
||||
secondary: secondary external depth
|
||||
wrist: wrist depth
|
||||
|
||||
# Always 8-dim =>> changes based on `StateEncoding`
|
||||
state_obs_keys:
|
||||
StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
|
||||
StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
|
||||
StateEncoding.JOINT: Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
|
||||
|
||||
state_encoding: Type of `StateEncoding`
|
||||
action_encoding: Type of action encoding (e.g., EEF Position vs. Joint Position)
|
||||
"""
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
# Defines Proprioceptive State Encoding Schemes
|
||||
class StateEncoding(IntEnum):
|
||||
# fmt: off
|
||||
NONE = -1 # No Proprioceptive State
|
||||
POS_EULER = 1 # EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
|
||||
POS_QUAT = 2 # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
|
||||
JOINT = 3 # Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
|
||||
JOINT_BIMANUAL = 4 # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ])
|
||||
# fmt: on
|
||||
|
||||
|
||||
# Defines Action Encoding Schemes
|
||||
class ActionEncoding(IntEnum):
|
||||
# fmt: off
|
||||
EEF_POS = 1 # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1)
|
||||
JOINT_POS = 2 # Joint Delta Position (7) + Gripper Open/Close (1)
|
||||
JOINT_POS_BIMANUAL = 3 # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ])
|
||||
EEF_R6 = 4 # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1)
|
||||
# fmt: on
|
||||
|
||||
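
# NOTE: minimal sanity-check sketch (illustrative) — because these are `IntEnum`s,
# encoding values compare and serialize as plain ints, so they can be stored in the
# per-dataset configs below and in exported metadata without a custom encoder:
#   assert StateEncoding.POS_EULER == 1
#   assert int(ActionEncoding.EEF_POS) == 1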


# === Individual Dataset Configs ===
OXE_DATASET_CONFIGS = {
    "fractal20220817_data": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["base_pose_tool_reached", "gripper_closed"],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
        "fps": 3,  # TODO: placeholder (youliang)
    },
    "kuka": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": [
            "clip_function_input/base_pose_tool_reached",
            "gripper_closed",
        ],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "bridge_oxe": {  # Version of Bridge V2 in Open X-Embodiment mixture
        "image_obs_keys": {"primary": "image", "secondary": "image_1", "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["EEF_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "bridge_orig": {  # Original version of Bridge V2 from project website
        "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["EEF_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
        "fps": 5,
    },
    "bridge_dataset": {  # Original version of Bridge V2 from project website
        "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["EEF_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
        "fps": 5,
    },
    "taco_play": {
        "image_obs_keys": {
            "primary": "rgb_static",
            "secondary": None,
            "wrist": "rgb_gripper",
        },
        "depth_obs_keys": {
            "primary": "depth_static",
            "secondary": None,
            "wrist": "depth_gripper",
        },
        "state_obs_keys": ["state_eef", None, "state_gripper"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "jaco_play": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "image_wrist",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state_eef", None, "state_gripper"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "berkeley_cable_routing": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": "top_image",
            "wrist": "wrist45_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["robot_state", None],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "roboturk": {
        "image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": [None, None, None, None, None, None, None, None],
        "state_encoding": StateEncoding.NONE,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "nyu_door_opening_surprising_effectiveness": {
        "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": [None, None, None, None, None, None, None, None],
        "state_encoding": StateEncoding.NONE,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "viola": {
        "image_obs_keys": {
            "primary": "agentview_rgb",
            "secondary": None,
            "wrist": "eye_in_hand_rgb",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["joint_states", "gripper_states"],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "berkeley_autolab_ur5": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "hand_image",
        },
        "depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None},
        "state_obs_keys": ["state"],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "toto": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state", None],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "language_table": {
        "image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["effector_translation", None, None, None, None, None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "columbia_cairlab_pusht_real": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["robot_state", None, None, None, None, None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None},
        "state_obs_keys": ["ee_position", "ee_orientation", None],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "nyu_rot_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "stanford_hydra_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "austin_buds_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state"],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.EEF_POS,
        "fps": 20,
    },
    "nyu_franka_play_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": "image_additional_view",
            "wrist": None,
        },
        "depth_obs_keys": {
            "primary": "depth",
            "secondary": "depth_additional_view",
            "wrist": None,
        },
        "state_obs_keys": ["eef_state", None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
        "fps": 3,
    },
    "maniskill_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {
            "primary": "depth",
            "secondary": None,
            "wrist": "wrist_depth",
        },
        "state_obs_keys": ["tcp_pose", "gripper_state"],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "furniture_bench_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state"],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "cmu_franka_exploration_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "highres_image",
            "secondary": None,
            "wrist": None,
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": [None, None, None, None, None, None, None, None],
        "state_encoding": StateEncoding.NONE,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "ucsd_kitchen_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["joint_state", None],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "austin_sailor_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state"],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "austin_sirius_dataset_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state"],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "bc_z": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": [
            "present/xyz",
            "present/axis_angle",
            None,
            "present/sensed_close",
        ],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": "image2",
            "wrist": "hand_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["end_effector_pose", None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "utokyo_xarm_bimanual_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["pose_r", None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "robo_net": {
        "image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "berkeley_mvp_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["pose", "gripper"],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.JOINT_POS,
    },
    "berkeley_rpt_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["joint_pos", "gripper"],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.JOINT_POS,
    },
    "kaist_nonprehensile_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state", None],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "stanford_mask_vit_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "tokyo_u_lsmo_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "dlr_sara_pour_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state", None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "dlr_sara_grid_clamp_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state", None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "dlr_edan_shared_control_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state", None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "asu_table_top_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "stanford_robocook_converted_externally_to_rlds": {
        "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
        "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "imperialcollege_sawyer_wrist_cam": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": [None, None, None, None, None, None, None, "state"],
        "state_encoding": StateEncoding.NONE,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["joint_state", "gripper_state"],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "uiuc_d3field": {
        "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
        "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
        "state_obs_keys": [None, None, None, None, None, None, None, None],
        "state_encoding": StateEncoding.NONE,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "utaustin_mutex": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state"],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "berkeley_fanuc_manipulation": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "wrist_image",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["joint_state", None, "gripper_state"],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "cmu_playing_with_food": {
        "image_obs_keys": {
            "primary": "image",
            "secondary": None,
            "wrist": "finger_vision_1",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state", None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "cmu_play_fusion": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state"],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "cmu_stretch": {
        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["eef_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
        "fps": 10,
    },
    "berkeley_gnm_recon": {
        "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state", None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "berkeley_gnm_cory_hall": {
        "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state", None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "berkeley_gnm_sac_son": {
        "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["state", None, None],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "droid": {
        "image_obs_keys": {
            "primary": "exterior_image_1_left",
            "secondary": "exterior_image_2_left",
            "wrist": "wrist_image_left",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["proprio"],
        "state_encoding": StateEncoding.POS_QUAT,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "fmb_dataset": {
        "image_obs_keys": {
            "primary": "image_side_1",
            "secondary": "image_side_2",
            "wrist": "image_wrist_1",
        },
        "depth_obs_keys": {
            "primary": "image_side_1_depth",
            "secondary": "image_side_2_depth",
            "wrist": "image_wrist_1_depth",
        },
        "state_obs_keys": ["proprio"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "dobbe": {
        "image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["proprio"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "roboset": {
        "image_obs_keys": {
            "primary": "image_left",
            "secondary": "image_right",
            "wrist": "image_wrist",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["proprio"],
        "state_encoding": StateEncoding.JOINT,
        "action_encoding": ActionEncoding.JOINT_POS,
    },
    "rh20t": {
        "image_obs_keys": {
            "primary": "image_front",
            "secondary": "image_side_right",
            "wrist": "image_wrist",
        },
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["proprio"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    ### T-DROID datasets
    "tdroid_carrot_in_bowl": {  # "put carrot in bowl" task, 50 demos @ 5 Hz control
        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
        "state_obs_keys": ["EEF_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "tdroid_pour_corn_in_pot": {  # "pour corn from red bowl into steel pot" task, 50 demos @ 5 Hz control
        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
        "state_obs_keys": ["EEF_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "tdroid_flip_pot_upright": {  # "flip pot upright" task, 10 demos @ 5 Hz control
        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
        "state_obs_keys": ["EEF_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "tdroid_move_object_onto_plate": {  # "move <object> onto plate" task, 150 demos @ 5 Hz control
        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
        "state_obs_keys": ["EEF_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "tdroid_knock_object_over": {  # "knock <object> over" task, 70 demos @ 5 Hz control
        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
        "state_obs_keys": ["EEF_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    "tdroid_cover_object_with_towel": {  # "cover <object> with towel" task, 45 demos @ 5 Hz control
        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
        "state_obs_keys": ["EEF_state", None, "gripper_state"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
    ### DROID Finetuning datasets
    "droid_wipe": {
        "image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"},
        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
        "state_obs_keys": ["proprio"],
        "state_encoding": StateEncoding.POS_EULER,
        "action_encoding": ActionEncoding.EEF_POS,
    },
}
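
# NOTE: minimal usage sketch (illustrative) — standardizing a dataset starts by
# looking up its config entry:
#   config = OXE_DATASET_CONFIGS["bridge_orig"]
#   config["image_obs_keys"]["primary"]   # -> "image_0"
#   config["state_encoding"]              # -> StateEncoding.POS_EULER
#   config.get("fps")                     # -> 5 (None for datasets not yet populated)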

lerobot/common/datasets/push_dataset_to_hub/oxe/data_utils.py
@@ -0,0 +1,91 @@
"""
NOTE(YL): Adapted from:
    Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py

data_utils.py

Additional utils for data processing.
"""

from typing import Any, Dict, List

import tensorflow as tf


def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Converts gripper actions from continuous to binary values (0 and 1).

    We exploit the fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
    transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
    values based on the state that is reached _after_ those intermediate values.

    In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
    chunk of intermediate values as the last action in the trajectory.

    The `scan_fn` implements the following logic:
        new_actions = np.empty_like(actions)
        carry = actions[-1]
        for i in reversed(range(actions.shape[0])):
            if in_between_mask[i]:
                carry = carry
            else:
                carry = float(open_mask[i])
            new_actions[i] = carry
    """
    open_mask, closed_mask = actions > 0.95, actions < 0.05
    in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
    is_open_float = tf.cast(open_mask, tf.float32)

    def scan_fn(carry, i):
        return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])

    return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
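
# NOTE: minimal worked example (illustrative) of the relabeling above — each
# intermediate value is resolved by the *next* fully open/closed state reached:
#   binarize_gripper_actions(tf.constant([1.0, 0.9, 0.02, 0.5, 1.0]))
#   # -> [1.0, 0.0, 0.0, 1.0, 1.0]  (0.9 precedes a close; 0.5 precedes an open)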


def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    return 1 - actions


def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).

    Assumes that the first relative gripper action is not redundant (i.e., not a "close" when already closed)!
    """
    # Note =>> thresholded actions: -1 for closing, 1 for opening, 0 for no change
    opening_mask, closing_mask = actions < -0.1, actions > 0.1
    thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))

    def scan_fn(carry, i):
        return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])

    # If no relative grasp, assumes open for whole trajectory
    start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
    start = tf.cond(start == 0, lambda: 1, lambda: start)

    # Note =>> -1 for closed, 1 for open
    new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
    new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5

    return new_actions
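
# NOTE: minimal worked example (illustrative) — relative [no-op, close, no-op, open]
# becomes absolute open/closed states (1 = open, 0 = closed); the segment before the
# first change is inferred as the opposite of that first change:
#   rel2abs_gripper_actions(tf.constant([0.0, 1.0, 0.0, -1.0]))
#   # -> [1.0, 0.0, 0.0, 1.0]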


# === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
    """Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
    movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
    traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
    traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)

    return traj_truncated


# === RLDS Dataset Initialization Utilities ===
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
    print("\n######################################################################################")
    print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")
    for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights):
        pad = 80 - len(dataset_kwargs["name"])
        print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #")
    print("######################################################################################\n")

lerobot/common/datasets/push_dataset_to_hub/oxe/droid_utils.py
@@ -0,0 +1,183 @@
"""
|
||||
NOTE(YL): Adapted from:
|
||||
OpenVLA: https://github.com/openvla/openvla
|
||||
|
||||
Episode transforms for DROID dataset.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_graphics.geometry.transformation as tfg
|
||||
|
||||
|
||||
def rmat_to_euler(rot_mat):
|
||||
return tfg.euler.from_rotation_matrix(rot_mat)
|
||||
|
||||
|
||||
def euler_to_rmat(euler):
|
||||
return tfg.rotation_matrix_3d.from_euler(euler)
|
||||
|
||||
|
||||
def invert_rmat(rot_mat):
|
||||
return tfg.rotation_matrix_3d.inverse(rot_mat)
|
||||
|
||||
|
||||
def rotmat_to_rot6d(mat):
|
||||
"""
|
||||
Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
|
||||
Args:
|
||||
mat: rotation matrix
|
||||
|
||||
Returns: 6d vector (first two rows of rotation matrix)
|
||||
|
||||
"""
|
||||
r6 = mat[..., :2, :]
|
||||
r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
|
||||
r6_flat = tf.concat([r6_0, r6_1], axis=-1)
|
||||
return r6_flat
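
# NOTE: minimal sketch (illustrative; not used by the transforms in this file) of how
# the R6 representation is inverted: Gram-Schmidt the two stored rows back into an
# orthonormal pair, then recover the third row as their cross product.
def rot6d_to_rotmat(r6: tf.Tensor) -> tf.Tensor:
    a0, a1 = r6[..., :3], r6[..., 3:6]
    b0 = tf.math.l2_normalize(a0, axis=-1)
    # Project the b0 component out of a1, then normalize (Gram-Schmidt step)
    b1 = tf.math.l2_normalize(a1 - tf.reduce_sum(b0 * a1, axis=-1, keepdims=True) * b0, axis=-1)
    b2 = tf.linalg.cross(b0, b1)
    return tf.stack([b0, b1, b2], axis=-2)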


def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
    """
    Translates velocity actions (translation + rotation) from the robot base frame to the wrist frame.

    Args:
        velocity: 6d velocity action (3 x translation, 3 x rotation)
        wrist_in_robot_frame: 6d pose of the end-effector in robot base frame

    Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
    """
    R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
    R_frame_inv = invert_rmat(R_frame)

    # world to wrist: dT_pi = R^-1 dT_rbt
    vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0]

    # world to wrist: dR_pi = R^-1 dR_rbt R
    dR = euler_to_rmat(velocity[:, 3:6])
    dR = R_frame_inv @ (dR @ R_frame)
    dR_r6 = rotmat_to_rot6d(dR)
    return tf.concat([vel_t, dR_r6], axis=-1)


def rand_swap_exterior_images(img1, img2):
    """
    Randomly swaps the two exterior images (for training with a single exterior input).
    """
    return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))


def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in the *base* frame of the robot.
    """
    dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
    dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]

    trajectory["action"] = tf.concat(
        (
            dt,
            dR,
            1 - trajectory["action_dict"]["gripper_position"],
        ),
        axis=-1,
    )
    trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
        rand_swap_exterior_images(
            trajectory["observation"]["exterior_image_1_left"],
            trajectory["observation"]["exterior_image_2_left"],
        )
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["cartesian_position"],
            trajectory["observation"]["gripper_position"],
        ),
        axis=-1,
    )
    return trajectory


def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in the *wrist* frame of the robot.
    """
    wrist_act = velocity_act_to_wrist_frame(
        trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
    )
    trajectory["action"] = tf.concat(
        (
            wrist_act,
            trajectory["action_dict"]["gripper_position"],
        ),
        axis=-1,
    )
    trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
        rand_swap_exterior_images(
            trajectory["observation"]["exterior_image_1_left"],
            trajectory["observation"]["exterior_image_2_left"],
        )
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["cartesian_position"],
            trajectory["observation"]["gripper_position"],
        ),
        axis=-1,
    )
    return trajectory


def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation (fine-tuning variant; no random exterior-image swap) for actions
    expressed in the *base* frame of the robot.
    """
    dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
    dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
    trajectory["action"] = tf.concat(
        (
            dt,
            dR,
            1 - trajectory["action_dict"]["gripper_position"],
        ),
        axis=-1,
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["cartesian_position"],
            trajectory["observation"]["gripper_position"],
        ),
        axis=-1,
    )
    return trajectory


def zero_action_filter(traj: Dict) -> bool:
    """
    Filters transitions whose actions are all-0 (only relative actions, no gripper action).
    Note: this filter is applied *after* action normalization, so we need to compare against the "normalized 0" action.
    """
    DROID_Q01 = tf.convert_to_tensor(
        [
            -0.7776297926902771,
            -0.5803514122962952,
            -0.5795090794563293,
            -0.6464047729969025,
            -0.7041108310222626,
            -0.8895104378461838,
        ]
    )
    DROID_Q99 = tf.convert_to_tensor(
        [
            0.7597932070493698,
            0.5726242214441299,
            0.7351000607013702,
            0.6705610305070877,
            0.6464948207139969,
            0.8897542208433151,
        ]
    )
    DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1

    return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
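
# NOTE: minimal usage sketch (illustrative), assuming `dataset` is a tf.data.Dataset of
# normalized trajectory dicts with an "action" tensor of shape [T, >=6] — the filter
# keeps a trajectory only if at least one step deviates from the normalized zero action:
#   dataset = dataset.filter(zero_action_filter)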

lerobot/common/datasets/push_dataset_to_hub/oxe/transforms.py
@@ -0,0 +1,912 @@
"""
|
||||
NOTE(YL): Adapted from:
|
||||
OpenVLA: https://github.com/openvla/openvla
|
||||
Octo: https://github.com/octo-models/octo
|
||||
|
||||
transforms.py
|
||||
|
||||
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
|
||||
|
||||
Transforms adopt the following structure:
|
||||
Input: Dictionary of *batched* features (i.e., has leading time dimension)
|
||||
Output: Dictionary `step` =>> {
|
||||
"observation": {
|
||||
<image_keys, depth_image_keys>
|
||||
State (in chosen state representation)
|
||||
},
|
||||
"action": Action (in chosen action representation),
|
||||
"language_instruction": str
|
||||
}
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.oxe.data_utils import (
|
||||
binarize_gripper_actions,
|
||||
invert_gripper_actions,
|
||||
rel2abs_gripper_actions,
|
||||
relabel_bridge_actions,
|
||||
)
|
||||
|
||||
def droid_baseact_transform_fn():
|
||||
from lerobot.common.datasets.push_dataset_to_hub.oxe.droid_utils import droid_baseact_transform
|
||||
return droid_baseact_transform
|
||||
|
||||
|
||||
def droid_finetuning_transform_fn():
|
||||
from lerobot.common.datasets.push_dataset_to_hub.oxe.droid_utils import droid_finetuning_transform
|
||||
return droid_finetuning_transform


def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    Applies to the version of Bridge V2 in the Open X-Embodiment mixture.

    Note =>> In the original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
    """
    for key in trajectory.keys():
        if key == "traj_metadata":
            continue
        elif key in ["observation", "action"]:
            for key2 in trajectory[key]:
                trajectory[key][key2] = trajectory[key][key2][1:]
        else:
            trajectory[key] = trajectory[key][1:]

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    trajectory = relabel_bridge_actions(trajectory)
    trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory
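
# NOTE: minimal sketch (illustrative) of the standardized contract the transform above
# produces, matching the module docstring; shapes assume a trajectory of length T:
#   trajectory["action"]                        # float32 [T, 7]: dxyz (3) + drpy (3) + gripper (1)
#   trajectory["language_instruction"]          # string  [T]
#   trajectory["observation"]["EEF_state"]      # float32 [T, 6]
#   trajectory["observation"]["gripper_state"]  # float32 [T, 1]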


def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    Applies to the original version of Bridge V2 from the official project website.

    Note =>> In the original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
    """
    for key in trajectory.keys():
        if key == "traj_metadata":
            continue
        elif key == "observation":
            for key2 in trajectory[key]:
                trajectory[key][key2] = trajectory[key][key2][1:]
        else:
            trajectory[key] = trajectory[key][1:]

    trajectory["action"] = tf.concat(
        [
            trajectory["action"][:, :6],
            binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
        ],
        axis=1,
    )
    trajectory = relabel_bridge_actions(trajectory)
    trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory


def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        [
            trajectory["action"][:, :6],
            binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
        ],
        axis=1,
    )
    trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
    return trajectory


def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action[:, None],
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action[:, None],
        ),
        axis=-1,
    )
    # decode compressed state
    eef_value = tf.io.decode_compressed(
        trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
        compression_type="ZLIB",
    )
    eef_value = tf.io.decode_raw(eef_value, tf.float32)
    trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
    gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB")
    gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
    trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
    # )  # delete uninformative language instruction
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
    trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
    trajectory["action"] = trajectory["action"]["rel_actions_world"]

    # clip gripper action, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
        ),
        axis=-1,
    )

    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
    trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:]

    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            tf.zeros_like(trajectory["action"]["world_vector"]),
            gripper_action[:, None],
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
        ),
        axis=-1,
    )
    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
    # )  # delete uninformative language instruction
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert absolute gripper action, +1 = open, 0 = close
    gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1))

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action,
        ),
        axis=-1,
    )
    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
    # )  # delete uninformative language instruction
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action[:, None],
        ),
        axis=-1,
    )
    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
    # )  # delete uninformative language instruction
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # make gripper action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
    gripper_action = tf.clip_by_value(gripper_action, 0, 1)
    gripper_action = invert_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action,
        ),
        axis=-1,
    )
    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
    # )  # delete uninformative language instruction
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
    trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth")

    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action[:, None],
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
        ),
        axis=-1,
    )
    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
    # )  # delete uninformative language instruction
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # pad 2d translation action to 7-dim; default to "open" gripper (1.0)
    trajectory["action"] = tf.concat(
        (
            trajectory["action"],
            tf.zeros_like(trajectory["action"]),
            tf.zeros_like(trajectory["action"]),
            tf.ones_like(trajectory["action"][:, :1]),
        ),
        axis=-1,
    )

    # decode language instruction
    instruction_bytes = trajectory["observation"]["instruction"]
    instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
    # Remove trailing padding --> convert RaggedTensor to regular Tensor.
    trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0]
    return trajectory
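
# NOTE: minimal worked example (illustrative) of the instruction decoding above,
# where `instruction` stores zero-padded Unicode codepoints:
#   codepoints = tf.constant([[112, 117, 115, 104, 0, 0]])       # "push" + padding
#   encoded = tf.strings.unicode_encode(codepoints, "UTF-8")     # [b"push\x00\x00"]
#   tf.strings.split(encoded, "\x00")[:, :1].to_tensor()[:, 0]   # [b"push"]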
|
||||
|
||||
|
||||
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
trajectory["action"]["gripper_closedness_action"][:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tf.zeros_like(trajectory["action"][:, :3]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||


def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
    trajectory["action"] = trajectory["action"][..., :7]
    return trajectory


def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(trajectory["action"][:, -1:]),
        ),
        axis=-1,
    )

    trajectory["observation"]["eef_state"] = tf.concat(
        (
            trajectory["observation"]["state"][:, :3],
            trajectory["observation"]["state"][:, 7:10],
        ),
        axis=-1,
    )
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["language_instruction"]), ""
    # )  # delete uninformative language instruction
    return trajectory


def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )

    trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["language_instruction"]), ""
    # )  # delete uninformative language instruction
    return trajectory


def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
    trajectory["observation"]["depth_additional_view"] = tf.cast(
        trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
    )
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]

    # clip gripper action, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, -8:-2],
            tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
        ),
        axis=-1,
    )

    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["language_instruction"]), ""
    # )  # delete uninformative language instruction
    return trajectory


def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
    return trajectory


def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    import tensorflow_graphics.geometry.transformation as tft

    trajectory["observation"]["state"] = tf.concat(
        (
            trajectory["observation"]["state"][:, :7],
            trajectory["observation"]["state"][:, -1:],
        ),
        axis=-1,
    )

    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )
    return trajectory
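# NOTE: `tft.euler.from_quaternion` converts the (N, 4) quaternion slice to (N, 3)
# roll-pitch-yaw, so the concatenated action above is again 7-dimensional. The import
# is kept inside the function so `tensorflow_graphics` stays an optional dependency.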


def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            tf.zeros_like(trajectory["action"][:, :3]),
            trajectory["action"][:, -1:],
        ),
        axis=-1,
    )
    return trajectory


def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )

    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["language_instruction"]), ""
    # )  # delete uninformative language instruction
    return trajectory


def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )

    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["language_instruction"]), ""
    # )  # delete uninformative language instruction
    return trajectory


def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["future/xyz_residual"][:, :3],
            trajectory["action"]["future/axis_angle_residual"][:, :3],
            invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory
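# NOTE: unlike most transforms above, bc_z reads from a nested action dict:
# `future/xyz_residual` (3) + `future/axis_angle_residual` (3) + the inverted
# `future/target_close` flag (1) are flattened into a single 7-dim EEF action.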


def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def utokyo_xarm_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    return trajectory


def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = trajectory["action"][..., -7:]
    return trajectory


def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = tf.concat(
        (
            trajectory["observation"]["state"][:, :4],
            tf.zeros_like(trajectory["observation"]["state"][:, :2]),
        ),
        axis=-1,
    )
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :4],
            tf.zeros_like(trajectory["action"][:, :2]),
            trajectory["action"][:, -1:],
        ),
        axis=-1,
    )
    return trajectory


def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    return trajectory


def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    return trajectory


def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            tf.zeros_like(trajectory["action"][:, :1]),
        ),
        axis=-1,
    )
    return trajectory


def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = tf.concat(
        (
            trajectory["observation"]["end_effector_pose"][:, :4],
            tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
        ),
        axis=-1,
    )
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :4],
            tf.zeros_like(trajectory["action"][:, :2]),
            trajectory["action"][:, -1:],
        ),
        axis=-1,
    )
    return trajectory


def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory


def dlr_sara_pour_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    return trajectory


def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
    return trajectory


def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(trajectory["action"][:, -1:]),
        ),
        axis=-1,
    )
    return trajectory


def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory


def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory


def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    import tensorflow_graphics.geometry.transformation as tft

    trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
            trajectory["action"][:, 7:8],
        ),
        axis=-1,
    )
    return trajectory


def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"],
            tf.zeros_like(trajectory["action"]),
            tf.zeros_like(trajectory["action"][:, :1]),
        ),
        axis=-1,
    )
    return trajectory


def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]

    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )

    # trajectory["language_instruction"] = tf.fill(
    #     tf.shape(trajectory["language_instruction"]), ""
    # )  # delete uninformative language instruction
    return trajectory


def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]

    # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"],
            invert_gripper_actions(trajectory["observation"]["gripper_state"]),
        ),
        axis=-1,
    )
    return trajectory


def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    import tensorflow_graphics.geometry.transformation as tft

    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
            trajectory["action"][:, -1:],
        ),
        axis=-1,
    )
    return trajectory


def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            trajectory["action"][:, -4:],
        ),
        axis=-1,
    )
    return trajectory


def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = tf.concat(
        (
            trajectory["observation"]["state"][:, :3],
            tf.zeros_like(trajectory["observation"]["state"][:, :3]),
        ),
        axis=-1,
    )
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = tf.concat(
        (
            trajectory["observation"]["position"],
            tf.zeros_like(trajectory["observation"]["state"][:, :3]),
            trajectory["observation"]["yaw"],
        ),
        axis=-1,
    )
    trajectory["action"] = tf.concat(
        (
            trajectory["action"],
            tf.zeros_like(trajectory["action"]),
            tf.zeros_like(trajectory["action"]),
            tf.zeros_like(trajectory["action"][:, :1]),
        ),
        axis=-1,
    )
    return trajectory
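# NOTE: the berkeley_gnm_* navigation datasets are 2D: the state is `position`
# zero-padded plus `yaw`, and the 2-dim waypoint action is zero-padded out to 7 dims
# so it matches the manipulation action layout downstream.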


def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # every input feature is batched, i.e. has a leading batch dimension
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["eef_pose"],
            trajectory["observation"]["state_gripper_pose"][..., None],
        ),
        axis=-1,
    )
    return trajectory


def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # every input feature is batched, i.e. has a leading batch dimension
    trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
    return trajectory


def roboset_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # every input feature is batched, i.e. has a leading batch dimension
    trajectory["observation"]["proprio"] = trajectory["observation"]["state"]

    # gripper action is in -1...1 --> clip to 0...1, flip
    gripper_action = trajectory["action"][:, -1:]
    gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))

    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :7],
            gripper_action,
        ),
        axis=-1,
    )
    return trajectory


def rh20t_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["tcp_base"],
            tf.cast(trajectory["action"]["gripper"][:, None], tf.float32),
        ),
        axis=-1,
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["tcp_base"],
            trajectory["observation"]["gripper_width"][..., None],
        ),
        axis=-1,
    )
    return trajectory


def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        [
            trajectory["action"][:, :6],
            binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
        ],
        axis=1,
    )
    trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
    return trajectory
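# NOTE: `binarize_gripper_actions` (used by the T-DROID transform above) is another
# helper defined earlier in the file. Conceptually it snaps a continuous gripper
# signal to {0, 1}; a simplified sketch of the assumed behavior:
#
#     def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
#         return tf.cast(actions > 0.5, tf.float32)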


# === Registry ===
OXE_STANDARDIZATION_TRANSFORMS = {
    "bridge_oxe": bridge_oxe_dataset_transform,
    "bridge_orig": bridge_orig_dataset_transform,
    "bridge_dataset": bridge_orig_dataset_transform,
    "ppgm": ppgm_dataset_transform,
    "ppgm_static": ppgm_dataset_transform,
    "ppgm_wrist": ppgm_dataset_transform,
    "fractal20220817_data": rt1_dataset_transform,
    "kuka": kuka_dataset_transform,
    "taco_play": taco_play_dataset_transform,
    "jaco_play": jaco_play_dataset_transform,
    "berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
    "roboturk": roboturk_dataset_transform,
    "nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
    "viola": viola_dataset_transform,
    "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
    "toto": toto_dataset_transform,
    "language_table": language_table_dataset_transform,
    "columbia_cairlab_pusht_real": pusht_dataset_transform,
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
    "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
    "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
    "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
    "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
    "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
    "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
    "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
    "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
    "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
    "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
    "bc_z": bc_z_dataset_transform,
    "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
    "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform,
    "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
    "robo_net": robo_net_dataset_transform,
    "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
    "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
    "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
    "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
    "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
    "dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform,
    "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
    "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
    "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
    "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
    "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
    "uiuc_d3field": uiuc_d3field_dataset_transform,
    "utaustin_mutex": utaustin_mutex_dataset_transform,
    "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
    "cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
    "cmu_play_fusion": playfusion_dataset_transform,
    "cmu_stretch": cmu_stretch_dataset_transform,
    "berkeley_gnm_recon": gnm_dataset_transform,
    "berkeley_gnm_cory_hall": gnm_dataset_transform,
    "berkeley_gnm_sac_son": gnm_dataset_transform,
    "droid": droid_baseact_transform_fn,
    "fmb_dataset": fmb_dataset_transform,
    "dobbe": dobbe_dataset_transform,
    "roboset": roboset_dataset_transform,
    "rh20t": rh20t_dataset_transform,
    ### T-DROID datasets
    "tdroid_carrot_in_bowl": tdroid_dataset_transform,
    "tdroid_pour_corn_in_pot": tdroid_dataset_transform,
    "tdroid_flip_pot_upright": tdroid_dataset_transform,
    "tdroid_move_object_onto_plate": tdroid_dataset_transform,
    "tdroid_knock_object_over": tdroid_dataset_transform,
    "tdroid_cover_object_with_towel": tdroid_dataset_transform,
    ### DROID Finetuning datasets
    "droid_wipe": droid_finetuning_transform_fn,
}
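# Example usage (a sketch; the directory path and dataset name are illustrative, and
# episodes must first be flattened to trajectory-level tensors as done in
# oxe_rlds_format.py below):
#
#     import tensorflow_datasets as tfds
#
#     builder = tfds.builder_from_directory("/data/bridge_dataset/1.0.0")
#     dataset = builder.as_dataset(split="all")
#     transform_fn = OXE_STANDARDIZATION_TRANSFORMS["bridge_dataset"]
#     dataset = dataset.map(transform_fn)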

@@ -7,11 +7,10 @@ Example:
    python lerobot/scripts/push_dataset_to_hub.py \
    --raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
    --repo-id youliangtan/sampled_bridge_data_v2 \
    --raw-format oxe_rlds \
    --episodes 3 4 5 8 9 \
    --fps 5
    --raw-format oxe_rlds.bridge_orig \
    --episodes 3 4 5 8 9

Exact dataset fps is specified in:
Exact dataset fps is defined in oxe/configs.py, obtained from:
    https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
"""

@@ -19,12 +18,15 @@ import shutil
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.oxe.configs import OXE_DATASET_CONFIGS
from lerobot.common.datasets.push_dataset_to_hub.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,

@@ -43,7 +45,51 @@ def tf_to_torch(data):
    return torch.from_numpy(data.numpy())


def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
def tf_img_convert(img):
    if img.dtype == tf.string:
        img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
    elif img.dtype != tf.uint8:
        raise ValueError(f"Unsupported image dtype: {img.dtype}")
    return img.numpy()
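# NOTE: with `tfds.decode.SkipDecoding()` (used when loading the dataset below), image
# features arrive either as encoded byte strings (tf.string) or as raw uint8 tensors;
# `tf_img_convert` above handles both and returns a decoded numpy array.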
def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
    """
    In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
    entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
    trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.

    NOTE: adapted from DLimp library https://github.com/kvablack/dlimp/
    """
    steps = traj.pop("steps")

    traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]

    # broadcast metadata to the length of the trajectory
    metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)

    # put steps back in
    assert "traj_metadata" not in steps
    traj = {**steps, "traj_metadata": metadata}

    assert "_len" not in traj
    assert "_traj_index" not in traj
    assert "_frame_index" not in traj
    traj["_len"] = tf.repeat(traj_len, traj_len)
    traj["_traj_index"] = tf.repeat(i, traj_len)
    traj["_frame_index"] = tf.range(traj_len)

    return traj
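# For intuition, a trajectory of length 3 with top-level metadata {"episode_id": 7}
# becomes, after `dataset.enumerate().map(_broadcast_metadata_rlds)` (a schematic
# sketch, not real feature names):
#
#     {
#         "observation": {...},  # tensors from "steps", leading dim 3
#         "action": ...,         # leading dim 3
#         "traj_metadata": {"episode_id": [7, 7, 7]},
#         "_len": [3, 3, 3],
#         "_traj_index": [i, i, i],
#         "_frame_index": [0, 1, 2],
#     }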


def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int],
    oxe_dataset_name: str | None = None,
):
    """
    Args:
        raw_dir (Path): directory containing the raw tensorflow_datasets dataset to convert
@@ -53,22 +99,32 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
        episodes (list[int] | None, optional): subset of episode indices to convert. Defaults to None (all).
    """
    ds_builder = tfds.builder_from_directory(str(raw_dir))
    dataset = ds_builder.as_dataset(split="all")
    dataset = ds_builder.as_dataset(
        split="all",
        decoders={"steps": tfds.decode.SkipDecoding()},
    )
    dataset_info = ds_builder.info
    print("dataset_info: ", dataset_info)

    image_keys = get_cameras_keys(dataset_info.features["steps"]["observation"].keys())
    ds_length = len(dataset)
    dataset = dataset.take(ds_length)

    # check if there's a 'tfds.features.Text' in step, only take 1 lang instruction
    lang_key = [
        key for key, value in dataset_info.features["steps"].items() if isinstance(value, tfds.features.Text)
    ]
    lang_key = None if len(lang_key) == 0 else lang_key[0]
    # "flatten" the dataset so that we can easily apply a trajectory-level map();
    # each [obs][key] then has a shape of (frame_size, ...)
    dataset = dataset.enumerate().map(_broadcast_metadata_rlds)

    # apply the standardization transform if the dataset name is provided
    if oxe_dataset_name is not None:
        print(" - applying standardization transform for dataset: ", oxe_dataset_name)
        assert oxe_dataset_name in OXE_STANDARDIZATION_TRANSFORMS
        transform_fn = OXE_STANDARDIZATION_TRANSFORMS[oxe_dataset_name]
        dataset = dataset.map(transform_fn)

    image_keys = get_cameras_keys(dataset_info.features["steps"]["observation"].keys())
    lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
    print(" - image_keys: ", image_keys)
    print(" - lang_key: ", lang_key)

    ds_length = len(dataset)
    dataset = dataset.take(ds_length)
    it = iter(dataset)

    ep_dicts = []

@@ -83,7 +139,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
    for ep_idx in tqdm.tqdm(range(ds_length)):
        episode = next(it)

        # if we user specified episodes, skip the ones not in the list
        # if user specified episodes, skip the ones not in the list
        if episodes is not None:
            if len(episodes) == 0:
                break

@@ -94,38 +150,49 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
            else:
                continue  # skip

        steps = episode["steps"]
        num_frames = len(steps)
        num_frames = episode["action"].shape[0]

        ###########################################################
        # Handle the episodic data

        # last step of demonstration is considered done
        done = torch.zeros(num_frames, dtype=torch.bool)
        done[-1] = True

        states = []
        actions = []  # TODO(YL): some actions can be a featuredict
        rewards = torch.zeros(num_frames, dtype=torch.float32)
        ep_dict = {}
        langs = []
        langs = []  # TODO: might be located in "observation"

        image_array_dict = {key: [] for key in image_keys}

        ###########################################################
        # loop through all steps in the episode
        for j, step in enumerate(steps):
            states.append(tf_to_torch(step["observation"]["state"]))
            actions.append(tf_to_torch(step["action"]))
            rewards[j] = torch.tensor(step["reward"].numpy(), dtype=torch.float32)
        # We will create the state observation tensor by stacking the state
        # obs keys defined in oxe/configs.py
        if oxe_dataset_name is not None:
            state_obs_keys = OXE_DATASET_CONFIGS[oxe_dataset_name]["state_obs_keys"]
            # stack the state observations; if a key is missing, pad with zeros
            states = []
            for key in state_obs_keys:
                if key in episode["observation"]:
                    states.append(tf_to_torch(episode["observation"][key]))
                else:
                    states.append(torch.zeros(num_frames, 1))  # pad with zeros
            states = torch.cat(states, dim=1)
            assert states.shape == (num_frames, 8)
        else:
            states = tf_to_torch(episode["observation"]["state"])
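# NOTE: each missing key in `state_obs_keys` contributes one zero column, so the
# stacked state is always exactly 8-dim (hence the assert above), regardless of which
# proprioceptive fields a given OXE dataset actually provides.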

        if lang_key is not None:
            langs.append(str(step[lang_key]))
        actions = tf_to_torch(episode["action"])
        rewards = tf_to_torch(episode["reward"]).float()

        for im_key in image_keys:
            if im_key not in step["observation"]:
                continue
        # If lang_key is present, convert the entire tensor at once
        if lang_key is not None:
            langs = [str(x) for x in episode[lang_key]]

            img = step["observation"][im_key]
            img = np.array(img)
            image_array_dict[im_key].append(img)
        for im_key in image_keys:
            imgs = episode["observation"][im_key]
            image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]

        # simple assertions
        for item in [states, actions, rewards, done]:
            assert len(item) == num_frames

        ###########################################################

@@ -157,12 +224,12 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
        if lang_key is not None:
            ep_dict["language_instruction"] = langs

        ep_dict["observation.state"] = torch.stack(states)  # TODO better way
        ep_dict["action"] = torch.stack(actions)
        ep_dict["observation.state"] = states
        ep_dict["action"] = actions
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["reward"] = rewards
        ep_dict["next.reward"] = rewards
        ep_dict["next.done"] = done

        ep_dicts.append(ep_dict)

@@ -204,7 +271,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["reward"] = Value(dtype="float32", id=None)
    features["next.reward"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

@@ -219,12 +286,22 @@ def from_raw_to_lerobot_format(
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    oxe_dataset_name: str | None = None,
):
    """This is a test impl for rlds conversion"""
    if fps is None:
        fps = 5
        if oxe_dataset_name is not None:
            if "fps" not in OXE_DATASET_CONFIGS[oxe_dataset_name]:
                raise ValueError(
                    "fps for this dataset is not specified in oxe/configs.py yet, "
                    "which means it is not yet tested"
                )
            fps = OXE_DATASET_CONFIGS[oxe_dataset_name]["fps"]
        else:
            print(" - WARNING: fps is not provided, using default value of 5 fps")
            fps = 5
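# NOTE: the fps resolution order is therefore: an explicit --fps flag wins; otherwise
# the per-dataset `fps` entry in oxe/configs.py is used when an OXE dataset name is
# given; otherwise the conversion falls back to the default of 5 fps.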

    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, oxe_dataset_name)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {

@@ -65,7 +65,7 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "aloha_hdf5":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
    elif raw_format == "oxe_rlds":
    elif "oxe_rlds" in raw_format:
        from lerobot.common.datasets.push_dataset_to_hub.oxe_rlds_format import from_raw_to_lerobot_format
    elif raw_format == "dora_parquet":
        from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format

@@ -192,9 +192,27 @@ def push_dataset_to_hub(

    # convert dataset from original raw format to LeRobot format
    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir, videos_dir, fps, video, episodes
    )

    if "oxe_rlds" in raw_format:
        # User can provide the official OXE dataset name to convert it to LeRobot format.
        # The raw_format string takes one of two forms:
        #   oxe_rlds (default)
        #   oxe_rlds.bridge_orig (with bridge_orig as oxe_dataset_name)
        split_raw_format = raw_format.split(".")
        assert len(split_raw_format) <= 2, f"Invalid raw_format: {raw_format}"
        if len(split_raw_format) == 2:
            oxe_dataset_name = split_raw_format[1]
            print(f"Converting dataset [{oxe_dataset_name}] from 'oxe_rlds' to LeRobot format.")
        else:
            oxe_dataset_name = None
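# For example (illustrative values): raw_format "oxe_rlds.bridge_orig" splits into
# ["oxe_rlds", "bridge_orig"], so oxe_dataset_name becomes "bridge_orig", while a
# plain "oxe_rlds" leaves oxe_dataset_name as None.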

        hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
            raw_dir, videos_dir, fps, video, episodes, oxe_dataset_name=oxe_dataset_name
        )
    else:
        hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
            raw_dir, videos_dir, fps, video, episodes
        )

    lerobot_dataset = LeRobotDataset.from_preloaded(
        repo_id=repo_id,