Support for converting OpenX datasets from RLDS format to LeRobotDataset (#354)
Signed-off-by: youliangtan <tan_you_liang@hotmail.com> Co-authored-by: Simon Alibert <alibert.sim@gmail.com> Co-authored-by: youliangtan <tan_you_liang@hotmail.com> Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
parent
aad59e6b6b
commit
eb4c505cff
|
@ -129,6 +129,53 @@ available_real_world_datasets = [
|
|||
"lerobot/unitreeh1_rearrange_objects",
|
||||
"lerobot/unitreeh1_two_robot_greeting",
|
||||
"lerobot/unitreeh1_warehouse",
|
||||
"lerobot/nyu_rot_dataset",
|
||||
"lerobot/utokyo_saytap",
|
||||
"lerobot/imperialcollege_sawyer_wrist_cam",
|
||||
"lerobot/utokyo_xarm_bimanual",
|
||||
"lerobot/tokyo_u_lsmo",
|
||||
"lerobot/utokyo_pr2_opening_fridge",
|
||||
"lerobot/cmu_franka_exploration_dataset",
|
||||
"lerobot/cmu_stretch",
|
||||
"lerobot/asu_table_top",
|
||||
"lerobot/utokyo_pr2_tabletop_manipulation",
|
||||
"lerobot/utokyo_xarm_pick_and_place",
|
||||
"lerobot/ucsd_kitchen_dataset",
|
||||
"lerobot/austin_buds_dataset",
|
||||
"lerobot/dlr_sara_grid_clamp",
|
||||
"lerobot/conq_hose_manipulation",
|
||||
"lerobot/columbia_cairlab_pusht_real",
|
||||
"lerobot/dlr_sara_pour",
|
||||
"lerobot/dlr_edan_shared_control",
|
||||
"lerobot/ucsd_pick_and_place_dataset",
|
||||
"lerobot/berkeley_cable_routing",
|
||||
"lerobot/nyu_franka_play_dataset",
|
||||
"lerobot/austin_sirius_dataset",
|
||||
"lerobot/cmu_play_fusion",
|
||||
"lerobot/berkeley_gnm_sac_son",
|
||||
"lerobot/nyu_door_opening_surprising_effectiveness",
|
||||
"lerobot/berkeley_fanuc_manipulation",
|
||||
"lerobot/jaco_play",
|
||||
"lerobot/viola",
|
||||
"lerobot/kaist_nonprehensile",
|
||||
"lerobot/berkeley_mvp",
|
||||
"lerobot/uiuc_d3field",
|
||||
"lerobot/berkeley_gnm_recon",
|
||||
"lerobot/austin_sailor_dataset",
|
||||
"lerobot/utaustin_mutex",
|
||||
"lerobot/roboturk",
|
||||
"lerobot/stanford_hydra_dataset",
|
||||
"lerobot/berkeley_autolab_ur5",
|
||||
"lerobot/stanford_robocook",
|
||||
"lerobot/toto",
|
||||
"lerobot/fmb",
|
||||
"lerobot/droid_100",
|
||||
"lerobot/berkeley_rpt",
|
||||
"lerobot/stanford_kuka_multimodal_dataset",
|
||||
"lerobot/iamlab_cmu_pickup_insert",
|
||||
"lerobot/taco_play",
|
||||
"lerobot/berkeley_gnm_cory_hall",
|
||||
"lerobot/usc_cloth_sim",
|
||||
]
|
||||
|
||||
available_datasets = list(
|
||||
|
|
|
@ -40,6 +40,10 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
|||
|
||||
stats_patterns = {}
|
||||
for key, feats_type in dataset.features.items():
|
||||
# NOTE: skip language_instruction embedding in stats computation
|
||||
if key == "language_instruction":
|
||||
continue
|
||||
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
|
|
|
@ -60,8 +60,8 @@ AVAILABLE_RAW_REPO_IDS = {
|
|||
"lerobot-raw/aloha_static_vinh_cup_left_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_vinh_cup_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_ziploc_slide_raw": "aloha_hdf5",
|
||||
"lerobot-raw/pusht_raw": "pusht_zarr",
|
||||
"lerobot-raw/umi_cup_in_the_wild_raw": "umi_zarr",
|
||||
"lerobot-raw/pusht_raw": "pusht_zarr",
|
||||
"lerobot-raw/unitreeh1_fold_clothes_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_rearrange_objects_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_two_robot_greeting_raw": "aloha_hdf5",
|
||||
|
@ -70,6 +70,74 @@ AVAILABLE_RAW_REPO_IDS = {
|
|||
"lerobot-raw/xarm_lift_medium_replay_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_push_medium_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_push_medium_replay_raw": "xarm_pkl",
|
||||
"lerobot-raw/fractal20220817_data_raw": "openx_rlds.fractal20220817_data",
|
||||
"lerobot-raw/kuka_raw": "openx_rlds.kuka",
|
||||
"lerobot-raw/bridge_openx_raw": "openx_rlds.bridge_openx",
|
||||
"lerobot-raw/taco_play_raw": "openx_rlds.taco_play",
|
||||
"lerobot-raw/jaco_play_raw": "openx_rlds.jaco_play",
|
||||
"lerobot-raw/berkeley_cable_routing_raw": "openx_rlds.berkeley_cable_routing",
|
||||
"lerobot-raw/roboturk_raw": "openx_rlds.roboturk",
|
||||
"lerobot-raw/nyu_door_opening_surprising_effectiveness_raw": "openx_rlds.nyu_door_opening_surprising_effectiveness",
|
||||
"lerobot-raw/viola_raw": "openx_rlds.viola",
|
||||
"lerobot-raw/berkeley_autolab_ur5_raw": "openx_rlds.berkeley_autolab_ur5",
|
||||
"lerobot-raw/toto_raw": "openx_rlds.toto",
|
||||
"lerobot-raw/language_table_raw": "openx_rlds.language_table",
|
||||
"lerobot-raw/columbia_cairlab_pusht_real_raw": "openx_rlds.columbia_cairlab_pusht_real",
|
||||
"lerobot-raw/stanford_kuka_multimodal_dataset_raw": "openx_rlds.stanford_kuka_multimodal_dataset",
|
||||
"lerobot-raw/nyu_rot_dataset_raw": "openx_rlds.nyu_rot_dataset",
|
||||
"lerobot-raw/io_ai_tech_raw": "openx_rlds.io_ai_tech",
|
||||
"lerobot-raw/stanford_hydra_dataset_raw": "openx_rlds.stanford_hydra_dataset",
|
||||
"lerobot-raw/austin_buds_dataset_raw": "openx_rlds.austin_buds_dataset",
|
||||
"lerobot-raw/nyu_franka_play_dataset_raw": "openx_rlds.nyu_franka_play_dataset",
|
||||
"lerobot-raw/maniskill_dataset_raw": "openx_rlds.maniskill_dataset",
|
||||
"lerobot-raw/furniture_bench_dataset_raw": "openx_rlds.furniture_bench_dataset",
|
||||
"lerobot-raw/cmu_franka_exploration_dataset_raw": "openx_rlds.cmu_franka_exploration_dataset",
|
||||
"lerobot-raw/ucsd_kitchen_dataset_raw": "openx_rlds.ucsd_kitchen_dataset",
|
||||
"lerobot-raw/ucsd_pick_and_place_dataset_raw": "openx_rlds.ucsd_pick_and_place_dataset",
|
||||
"lerobot-raw/spoc_raw": "openx_rlds.spoc",
|
||||
"lerobot-raw/austin_sailor_dataset_raw": "openx_rlds.austin_sailor_dataset",
|
||||
"lerobot-raw/austin_sirius_dataset_raw": "openx_rlds.austin_sirius_dataset",
|
||||
"lerobot-raw/bc_z_raw": "openx_rlds.bc_z",
|
||||
"lerobot-raw/utokyo_pr2_opening_fridge_raw": "openx_rlds.utokyo_pr2_opening_fridge",
|
||||
"lerobot-raw/utokyo_pr2_tabletop_manipulation_raw": "openx_rlds.utokyo_pr2_tabletop_manipulation",
|
||||
"lerobot-raw/utokyo_xarm_pick_and_place_raw": "openx_rlds.utokyo_xarm_pick_and_place",
|
||||
"lerobot-raw/utokyo_xarm_bimanual_raw": "openx_rlds.utokyo_xarm_bimanual",
|
||||
"lerobot-raw/utokyo_saytap_raw": "openx_rlds.utokyo_saytap",
|
||||
"lerobot-raw/robo_net_raw": "openx_rlds.robo_net",
|
||||
"lerobot-raw/robo_set_raw": "openx_rlds.robo_set",
|
||||
"lerobot-raw/berkeley_mvp_raw": "openx_rlds.berkeley_mvp",
|
||||
"lerobot-raw/berkeley_rpt_raw": "openx_rlds.berkeley_rpt",
|
||||
"lerobot-raw/kaist_nonprehensile_raw": "openx_rlds.kaist_nonprehensile",
|
||||
"lerobot-raw/stanford_mask_vit_raw": "openx_rlds.stanford_mask_vit",
|
||||
"lerobot-raw/tokyo_u_lsmo_raw": "openx_rlds.tokyo_u_lsmo",
|
||||
"lerobot-raw/dlr_sara_pour_raw": "openx_rlds.dlr_sara_pour",
|
||||
"lerobot-raw/dlr_sara_grid_clamp_raw": "openx_rlds.dlr_sara_grid_clamp",
|
||||
"lerobot-raw/dlr_edan_shared_control_raw": "openx_rlds.dlr_edan_shared_control",
|
||||
"lerobot-raw/asu_table_top_raw": "openx_rlds.asu_table_top",
|
||||
"lerobot-raw/stanford_robocook_raw": "openx_rlds.stanford_robocook",
|
||||
"lerobot-raw/imperialcollege_sawyer_wrist_cam_raw": "openx_rlds.imperialcollege_sawyer_wrist_cam",
|
||||
"lerobot-raw/iamlab_cmu_pickup_insert_raw": "openx_rlds.iamlab_cmu_pickup_insert",
|
||||
"lerobot-raw/uiuc_d3field_raw": "openx_rlds.uiuc_d3field",
|
||||
"lerobot-raw/utaustin_mutex_raw": "openx_rlds.utaustin_mutex",
|
||||
"lerobot-raw/berkeley_fanuc_manipulation_raw": "openx_rlds.berkeley_fanuc_manipulation",
|
||||
"lerobot-raw/cmu_playing_with_food_raw": "openx_rlds.cmu_playing_with_food",
|
||||
"lerobot-raw/cmu_play_fusion_raw": "openx_rlds.cmu_play_fusion",
|
||||
"lerobot-raw/cmu_stretch_raw": "openx_rlds.cmu_stretch",
|
||||
"lerobot-raw/berkeley_gnm_recon_raw": "openx_rlds.berkeley_gnm_recon",
|
||||
"lerobot-raw/berkeley_gnm_cory_hall_raw": "openx_rlds.berkeley_gnm_cory_hall",
|
||||
"lerobot-raw/berkeley_gnm_sac_son_raw": "openx_rlds.berkeley_gnm_sac_son",
|
||||
"lerobot-raw/droid_raw": "openx_rlds.droid",
|
||||
"lerobot-raw/droid_100_raw": "openx_rlds.droid100",
|
||||
"lerobot-raw/fmb_raw": "openx_rlds.fmb",
|
||||
"lerobot-raw/dobbe_raw": "openx_rlds.dobbe",
|
||||
"lerobot-raw/usc_cloth_sim_raw": "openx_rlds.usc_cloth_sim",
|
||||
"lerobot-raw/plex_robosuite_raw": "openx_rlds.plex_robosuite",
|
||||
"lerobot-raw/conq_hose_manipulation_raw": "openx_rlds.conq_hose_manipulation",
|
||||
"lerobot-raw/vima_raw": "openx_rlds.vima",
|
||||
"lerobot-raw/robot_vqa_raw": "openx_rlds.robot_vqa",
|
||||
"lerobot-raw/mimic_play_raw": "openx_rlds.mimic_play",
|
||||
"lerobot-raw/tidybot_raw": "openx_rlds.tidybot",
|
||||
"lerobot-raw/eth_agent_affordances_raw": "openx_rlds.eth_agent_affordances",
|
||||
}
|
||||
|
||||
|
||||
|
@ -110,7 +178,7 @@ def download_all_raw_datasets(data_dir: Path | None = None):
|
|||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=f"""A script to download raw datasets from Hugging Face hub to a local directory. Here is a
|
||||
non exhaustive list of available repositories to use in `--repo-id`: {AVAILABLE_RAW_REPO_IDS}""",
|
||||
non exhaustive list of available repositories to use in `--repo-id`: {list(AVAILABLE_RAW_REPO_IDS.keys())}""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
|
@ -0,0 +1,640 @@
|
|||
OPENX_DATASET_CONFIGS:
|
||||
fractal20220817_data:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- base_pose_tool_reached
|
||||
- gripper_closed
|
||||
fps: 3
|
||||
|
||||
kuka:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- clip_function_input/base_pose_tool_reached
|
||||
- gripper_closed
|
||||
fps: 10
|
||||
|
||||
bridge_openx:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- EEF_state
|
||||
- gripper_state
|
||||
fps: 5
|
||||
|
||||
taco_play:
|
||||
image_obs_keys:
|
||||
- rgb_static
|
||||
- rgb_gripper
|
||||
depth_obs_keys:
|
||||
- depth_static
|
||||
- depth_gripper
|
||||
state_obs_keys:
|
||||
- state_eef
|
||||
- state_gripper
|
||||
fps: 15
|
||||
|
||||
jaco_play:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_wrist
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state_eef
|
||||
- state_gripper
|
||||
fps: 10
|
||||
|
||||
berkeley_cable_routing:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- top_image
|
||||
- wrist45_image
|
||||
- wrist225_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- robot_state
|
||||
fps: 10
|
||||
|
||||
roboturk:
|
||||
image_obs_keys:
|
||||
- front_rgb
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 10
|
||||
|
||||
nyu_door_opening_surprising_effectiveness:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 3
|
||||
|
||||
viola:
|
||||
image_obs_keys:
|
||||
- agentview_rgb
|
||||
- eye_in_hand_rgb
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_states
|
||||
- gripper_states
|
||||
fps: 20
|
||||
|
||||
berkeley_autolab_ur5:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- hand_image
|
||||
depth_obs_keys:
|
||||
- image_with_depth
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 5
|
||||
|
||||
toto:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 30
|
||||
|
||||
language_table:
|
||||
image_obs_keys:
|
||||
- rgb
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- effector_translation
|
||||
fps: 10
|
||||
|
||||
columbia_cairlab_pusht_real:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- robot_state
|
||||
fps: 10
|
||||
|
||||
stanford_kuka_multimodal_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- depth_image
|
||||
state_obs_keys:
|
||||
- ee_position
|
||||
- ee_orientation
|
||||
fps: 20
|
||||
|
||||
nyu_rot_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 3
|
||||
|
||||
io_ai_tech:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_fisheye
|
||||
- image_left_side
|
||||
- image_right_side
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 3
|
||||
|
||||
stanford_hydra_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
austin_buds_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
nyu_franka_play_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_additional_view
|
||||
depth_obs_keys:
|
||||
- depth
|
||||
- depth_additional_view
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
fps: 3
|
||||
|
||||
maniskill_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- depth
|
||||
- wrist_depth
|
||||
state_obs_keys:
|
||||
- tcp_pose
|
||||
- gripper_state
|
||||
fps: 20
|
||||
|
||||
furniture_bench_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
cmu_franka_exploration_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- highres_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 10
|
||||
|
||||
ucsd_kitchen_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_state
|
||||
fps: 2
|
||||
|
||||
ucsd_pick_and_place_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 3
|
||||
|
||||
spoc:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_manipulation
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 3
|
||||
|
||||
austin_sailor_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
austin_sirius_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
bc_z:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- present/xyz
|
||||
- present/axis_angle
|
||||
- present/sensed_close
|
||||
fps: 10
|
||||
|
||||
utokyo_pr2_opening_fridge_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
utokyo_xarm_pick_and_place_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image2
|
||||
- hand_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- end_effector_pose
|
||||
fps: 10
|
||||
|
||||
utokyo_xarm_bimanual_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- pose_r
|
||||
fps: 10
|
||||
|
||||
robo_net:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image1
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 1
|
||||
|
||||
robo_set:
|
||||
image_obs_keys:
|
||||
- image_left
|
||||
- image_right
|
||||
- image_wrist
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
- state_velocity
|
||||
fps: 5
|
||||
|
||||
berkeley_mvp_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- hand_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- gripper
|
||||
- pose
|
||||
- joint_pos
|
||||
fps: 5
|
||||
|
||||
berkeley_rpt_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- hand_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_pos
|
||||
- gripper
|
||||
fps: 30
|
||||
|
||||
kaist_nonprehensile_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
stanford_mask_vit_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
|
||||
tokyo_u_lsmo_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
dlr_sara_pour_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
dlr_sara_grid_clamp_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
dlr_edan_shared_control_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 5
|
||||
|
||||
asu_table_top_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 12.5
|
||||
|
||||
stanford_robocook_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image_1
|
||||
- image_2
|
||||
depth_obs_keys:
|
||||
- depth_1
|
||||
- depth_2
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 5
|
||||
|
||||
imperialcollege_sawyer_wrist_cam:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
iamlab_cmu_pickup_insert_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_state
|
||||
- gripper_state
|
||||
fps: 20
|
||||
|
||||
uiuc_d3field:
|
||||
image_obs_keys:
|
||||
- image_1
|
||||
- image_2
|
||||
depth_obs_keys:
|
||||
- depth_1
|
||||
- depth_2
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 1
|
||||
|
||||
utaustin_mutex:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
berkeley_fanuc_manipulation:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
cmu_playing_with_food:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- finger_vision_1
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
cmu_play_fusion:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 5
|
||||
|
||||
cmu_stretch:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
berkeley_gnm_recon:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
- position
|
||||
- yaw
|
||||
fps: 3
|
||||
|
||||
berkeley_gnm_cory_hall:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
- position
|
||||
- yaw
|
||||
fps: 5
|
||||
|
||||
berkeley_gnm_sac_son:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
- position
|
||||
- yaw
|
||||
fps: 10
|
||||
|
||||
droid:
|
||||
image_obs_keys:
|
||||
- exterior_image_1_left
|
||||
- exterior_image_2_left
|
||||
- wrist_image_left
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- proprio
|
||||
fps: 15
|
||||
|
||||
droid_100:
|
||||
image_obs_keys:
|
||||
- exterior_image_1_left
|
||||
- exterior_image_2_left
|
||||
- wrist_image_left
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- proprio
|
||||
fps: 15
|
||||
|
||||
fmb:
|
||||
image_obs_keys:
|
||||
- image_side_1
|
||||
- image_side_2
|
||||
- image_wrist_1
|
||||
- image_wrist_2
|
||||
depth_obs_keys:
|
||||
- image_side_1_depth
|
||||
- image_side_2_depth
|
||||
- image_wrist_1_depth
|
||||
- image_wrist_2_depth
|
||||
state_obs_keys:
|
||||
- proprio
|
||||
fps: 10
|
||||
|
||||
dobbe:
|
||||
image_obs_keys:
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- proprio
|
||||
fps: 3.75
|
||||
|
||||
usc_cloth_sim_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 10
|
||||
|
||||
plex_robosuite:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
conq_hose_manipulation:
|
||||
image_obs_keys:
|
||||
- frontleft_fisheye_image
|
||||
- frontright_fisheye_image
|
||||
- hand_color_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 30
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the Licens e.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
NOTE(YL): Adapted from:
|
||||
Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py
|
||||
|
||||
data_utils.py
|
||||
|
||||
Additional utils for data processing.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Converts gripper actions from continuous to binary values (0 and 1).
|
||||
|
||||
We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
|
||||
transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
|
||||
values based on the state that is reached _after_ those intermediate values.
|
||||
|
||||
In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
|
||||
chunk of intermediate values as the last action in the trajectory.
|
||||
|
||||
The `scan_fn` implements the following logic:
|
||||
new_actions = np.empty_like(actions)
|
||||
carry = actions[-1]
|
||||
for i in reversed(range(actions.shape[0])):
|
||||
if in_between_mask[i]:
|
||||
carry = carry
|
||||
else:
|
||||
carry = float(open_mask[i])
|
||||
new_actions[i] = carry
|
||||
"""
|
||||
open_mask, closed_mask = actions > 0.95, actions < 0.05
|
||||
in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
|
||||
is_open_float = tf.cast(open_mask, tf.float32)
|
||||
|
||||
def scan_fn(carry, i):
|
||||
return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
|
||||
|
||||
return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
|
||||
|
||||
|
||||
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
||||
return 1 - actions
|
||||
|
||||
|
||||
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
|
||||
|
||||
Assumes that the first relative gripper is not redundant (i.e. close when already closed)!
|
||||
"""
|
||||
# Note =>> -1 for closing, 1 for opening, 0 for no change
|
||||
opening_mask, closing_mask = actions < -0.1, actions > 0.1
|
||||
thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))
|
||||
|
||||
def scan_fn(carry, i):
|
||||
return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])
|
||||
|
||||
# If no relative grasp, assumes open for whole trajectory
|
||||
start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
|
||||
start = tf.cond(start == 0, lambda: 1, lambda: start)
|
||||
|
||||
# Note =>> -1 for closed, 1 for open
|
||||
new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
|
||||
new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5
|
||||
|
||||
return new_actions
|
||||
|
||||
|
||||
# === Bridge-V2 =>> Dataset-Specific Transform ===
|
||||
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
|
||||
movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
|
||||
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
|
||||
traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)
|
||||
|
||||
return traj_truncated
|
||||
|
||||
|
||||
# === RLDS Dataset Initialization Utilities ===
|
||||
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
|
||||
print("\n######################################################################################")
|
||||
print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")
|
||||
for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights, strict=False):
|
||||
pad = 80 - len(dataset_kwargs["name"])
|
||||
print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #")
|
||||
print("######################################################################################\n")
|
|
@ -0,0 +1,200 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
NOTE(YL): Adapted from:
|
||||
OpenVLA: https://github.com/openvla/openvla
|
||||
|
||||
Episode transforms for DROID dataset.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_graphics.geometry.transformation as tfg
|
||||
|
||||
|
||||
def rmat_to_euler(rot_mat):
|
||||
return tfg.euler.from_rotation_matrix(rot_mat)
|
||||
|
||||
|
||||
def euler_to_rmat(euler):
|
||||
return tfg.rotation_matrix_3d.from_euler(euler)
|
||||
|
||||
|
||||
def invert_rmat(rot_mat):
|
||||
return tfg.rotation_matrix_3d.inverse(rot_mat)
|
||||
|
||||
|
||||
def rotmat_to_rot6d(mat):
|
||||
"""
|
||||
Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
|
||||
Args:
|
||||
mat: rotation matrix
|
||||
|
||||
Returns: 6d vector (first two rows of rotation matrix)
|
||||
|
||||
"""
|
||||
r6 = mat[..., :2, :]
|
||||
r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
|
||||
r6_flat = tf.concat([r6_0, r6_1], axis=-1)
|
||||
return r6_flat
|
||||
|
||||
|
||||
def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
|
||||
"""
|
||||
Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.
|
||||
Args:
|
||||
velocity: 6d velocity action (3 x translation, 3 x rotation)
|
||||
wrist_in_robot_frame: 6d pose of the end-effector in robot base frame
|
||||
|
||||
Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
|
||||
|
||||
"""
|
||||
r_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
|
||||
r_frame_inv = invert_rmat(r_frame)
|
||||
|
||||
# world to wrist: dT_pi = R^-1 dT_rbt
|
||||
vel_t = (r_frame_inv @ velocity[:, :3][..., None])[..., 0]
|
||||
|
||||
# world to wrist: dR_pi = R^-1 dR_rbt R
|
||||
dr_ = euler_to_rmat(velocity[:, 3:6])
|
||||
dr_ = r_frame_inv @ (dr_ @ r_frame)
|
||||
dr_r6 = rotmat_to_rot6d(dr_)
|
||||
return tf.concat([vel_t, dr_r6], axis=-1)
|
||||
|
||||
|
||||
def rand_swap_exterior_images(img1, img2):
|
||||
"""
|
||||
Randomly swaps the two exterior images (for training with single exterior input).
|
||||
"""
|
||||
return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
|
||||
|
||||
|
||||
def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
DROID dataset transformation for actions expressed in *base* frame of the robot.
|
||||
"""
|
||||
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
|
||||
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
dt,
|
||||
dr_,
|
||||
1 - trajectory["action_dict"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
|
||||
rand_swap_exterior_images(
|
||||
trajectory["observation"]["exterior_image_1_left"],
|
||||
trajectory["observation"]["exterior_image_2_left"],
|
||||
)
|
||||
)
|
||||
trajectory["observation"]["proprio"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["cartesian_position"],
|
||||
trajectory["observation"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
DROID dataset transformation for actions expressed in *wrist* frame of the robot.
|
||||
"""
|
||||
wrist_act = velocity_act_to_wrist_frame(
|
||||
trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
|
||||
)
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
wrist_act,
|
||||
trajectory["action_dict"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
|
||||
rand_swap_exterior_images(
|
||||
trajectory["observation"]["exterior_image_1_left"],
|
||||
trajectory["observation"]["exterior_image_2_left"],
|
||||
)
|
||||
)
|
||||
trajectory["observation"]["proprio"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["cartesian_position"],
|
||||
trajectory["observation"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
DROID dataset transformation for actions expressed in *base* frame of the robot.
|
||||
"""
|
||||
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
|
||||
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
dt,
|
||||
dr_,
|
||||
1 - trajectory["action_dict"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["proprio"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["cartesian_position"],
|
||||
trajectory["observation"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def zero_action_filter(traj: Dict) -> bool:
|
||||
"""
|
||||
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
|
||||
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
|
||||
"""
|
||||
droid_q01 = tf.convert_to_tensor(
|
||||
[
|
||||
-0.7776297926902771,
|
||||
-0.5803514122962952,
|
||||
-0.5795090794563293,
|
||||
-0.6464047729969025,
|
||||
-0.7041108310222626,
|
||||
-0.8895104378461838,
|
||||
]
|
||||
)
|
||||
droid_q99 = tf.convert_to_tensor(
|
||||
[
|
||||
0.7597932070493698,
|
||||
0.5726242214441299,
|
||||
0.7351000607013702,
|
||||
0.6705610305070877,
|
||||
0.6464948207139969,
|
||||
0.8897542208433151,
|
||||
]
|
||||
)
|
||||
droid_norm_0_act = (
|
||||
2 * (tf.zeros_like(traj["action"][:, :6]) - droid_q01) / (droid_q99 - droid_q01 + 1e-8) - 1
|
||||
)
|
||||
|
||||
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - droid_norm_0_act) > 1e-5)
|
|
@ -0,0 +1,859 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
NOTE(YL): Adapted from:
|
||||
OpenVLA: https://github.com/openvla/openvla
|
||||
Octo: https://github.com/octo-models/octo
|
||||
|
||||
transforms.py
|
||||
|
||||
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
|
||||
|
||||
Transforms adopt the following structure:
|
||||
Input: Dictionary of *batched* features (i.e., has leading time dimension)
|
||||
Output: Dictionary `step` =>> {
|
||||
"observation": {
|
||||
<image_keys, depth_image_keys>
|
||||
State (in chosen state representation)
|
||||
},
|
||||
"action": Action (in chosen action representation),
|
||||
"language_instruction": str
|
||||
}
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx.data_utils import (
|
||||
binarize_gripper_actions,
|
||||
invert_gripper_actions,
|
||||
rel2abs_gripper_actions,
|
||||
relabel_bridge_actions,
|
||||
)
|
||||
|
||||
|
||||
def droid_baseact_transform_fn():
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx.droid_utils import droid_baseact_transform
|
||||
|
||||
return droid_baseact_transform
|
||||
|
||||
|
||||
def bridge_openx_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Applies to version of Bridge V2 in Open X-Embodiment mixture.
|
||||
|
||||
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
||||
"""
|
||||
for key in trajectory:
|
||||
if key == "traj_metadata":
|
||||
continue
|
||||
elif key in ["observation", "action"]:
|
||||
for key2 in trajectory[key]:
|
||||
trajectory[key][key2] = trajectory[key][key2][1:]
|
||||
else:
|
||||
trajectory[key] = trajectory[key][1:]
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
trajectory = relabel_bridge_actions(trajectory)
|
||||
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Applies to original version of Bridge V2 from the official project website.
|
||||
|
||||
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
||||
"""
|
||||
for key in trajectory:
|
||||
if key == "traj_metadata":
|
||||
continue
|
||||
elif key == "observation":
|
||||
for key2 in trajectory[key]:
|
||||
trajectory[key][key2] = trajectory[key][key2][1:]
|
||||
else:
|
||||
trajectory[key] = trajectory[key][1:]
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
[
|
||||
trajectory["action"][:, :6],
|
||||
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
trajectory = relabel_bridge_actions(trajectory)
|
||||
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
[
|
||||
trajectory["action"][:, :6],
|
||||
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
# decode compressed state
|
||||
eef_value = tf.io.decode_compressed(
|
||||
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
|
||||
compression_type="ZLIB",
|
||||
)
|
||||
eef_value = tf.io.decode_raw(eef_value, tf.float32)
|
||||
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
|
||||
gripper_value = tf.io.decode_compressed(
|
||||
trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
|
||||
)
|
||||
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
|
||||
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
|
||||
trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
|
||||
trajectory["action"] = trajectory["action"]["rel_actions_world"]
|
||||
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
|
||||
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
|
||||
:, -1:
|
||||
]
|
||||
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
tf.zeros_like(trajectory["action"]["world_vector"]),
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert absolute gripper action, +1 = open, 0 = close
|
||||
gripper_action = invert_gripper_actions(
|
||||
tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
|
||||
)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action,
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
trajectory["language_embedding"] = trajectory["observation"]["natural_language_embedding"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# make gripper action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
|
||||
gripper_action = tf.clip_by_value(gripper_action, 0, 1)
|
||||
gripper_action = invert_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action,
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
|
||||
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# default to "open" gripper
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"],
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.ones_like(trajectory["action"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
# decode language instruction
|
||||
instruction_bytes = trajectory["observation"]["instruction"]
|
||||
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
|
||||
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
|
||||
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
|
||||
:, 0
|
||||
]
|
||||
return trajectory
|
||||
|
||||
|
||||
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
trajectory["action"]["gripper_closedness_action"][:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tf.zeros_like(trajectory["action"][:, :3]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
|
||||
trajectory["action"] = trajectory["action"][..., :7]
|
||||
return trajectory
|
||||
|
||||
|
||||
def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(trajectory["action"][:, -1:]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
trajectory["observation"]["eef_state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["state"][:, :3],
|
||||
trajectory["observation"]["state"][:, 7:10],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
|
||||
return trajectory
|
||||
|
||||
|
||||
def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
|
||||
return trajectory
|
||||
|
||||
|
||||
def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
|
||||
trajectory["observation"]["depth_additional_view"] = tf.cast(
|
||||
trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
|
||||
)
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]
|
||||
|
||||
# clip gripper action, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, -8:-2],
|
||||
tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
|
||||
return trajectory
|
||||
|
||||
|
||||
def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
import tensorflow_graphics.geometry.transformation as tft
|
||||
|
||||
trajectory["observation"]["state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["state"][:, :7],
|
||||
trajectory["observation"]["state"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tf.zeros_like(trajectory["action"][:, :3]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["future/xyz_residual"][:, :3],
|
||||
trajectory["action"]["future/axis_angle_residual"][:, :3],
|
||||
invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = trajectory["action"][..., -7:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["state"][:, :4],
|
||||
tf.zeros_like(trajectory["observation"]["state"][:, :2]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :4],
|
||||
tf.zeros_like(trajectory["action"][:, :2]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
trajectory["observation"]["state"] = tf.concat((
|
||||
tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32),
|
||||
trajectory["observation"]["pose"],
|
||||
trajectory["observation"]["joint_pos"],),
|
||||
axis=-1,)
|
||||
"""
|
||||
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
|
||||
return trajectory
|
||||
|
||||
|
||||
def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
tf.zeros_like(trajectory["action"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["end_effector_pose"][:, :4],
|
||||
tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :4],
|
||||
tf.zeros_like(trajectory["action"][:, :2]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
|
||||
return trajectory
|
||||
|
||||
|
||||
def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(trajectory["action"][:, -1:]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
import tensorflow_graphics.geometry.transformation as tft
|
||||
|
||||
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
|
||||
trajectory["action"][:, 7:8],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"],
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.zeros_like(trajectory["action"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
|
||||
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]
|
||||
|
||||
# dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"],
|
||||
invert_gripper_actions(trajectory["observation"]["gripper_state"]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
import tensorflow_graphics.geometry.transformation as tft
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
trajectory["action"][:, -4:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["state"][:, :3],
|
||||
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["position"],
|
||||
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
|
||||
trajectory["observation"]["yaw"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"],
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.zeros_like(trajectory["action"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def fmb_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# every input feature is batched, ie has leading batch dimension
|
||||
trajectory["observation"]["proprio"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["eef_pose"],
|
||||
trajectory["observation"]["state_gripper_pose"][..., None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# every input feature is batched, ie has leading batch dimension
|
||||
trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def robo_set_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# gripper action is in -1...1 --> clip to 0...1, flip
|
||||
gripper_action = trajectory["action"][:, -1:]
|
||||
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :7],
|
||||
gripper_action,
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def identity_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return trajectory
|
||||
|
||||
|
||||
# === Registry ===
|
||||
OPENX_STANDARDIZATION_TRANSFORMS = {
|
||||
"bridge_openx": bridge_openx_dataset_transform,
|
||||
"bridge_orig": bridge_orig_dataset_transform,
|
||||
"bridge_dataset": bridge_orig_dataset_transform,
|
||||
"ppgm": ppgm_dataset_transform,
|
||||
"ppgm_static": ppgm_dataset_transform,
|
||||
"ppgm_wrist": ppgm_dataset_transform,
|
||||
"fractal20220817_data": rt1_dataset_transform,
|
||||
"kuka": kuka_dataset_transform,
|
||||
"taco_play": taco_play_dataset_transform,
|
||||
"jaco_play": jaco_play_dataset_transform,
|
||||
"berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
|
||||
"roboturk": roboturk_dataset_transform,
|
||||
"nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
|
||||
"viola": viola_dataset_transform,
|
||||
"berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
|
||||
"toto": toto_dataset_transform,
|
||||
"language_table": language_table_dataset_transform,
|
||||
"columbia_cairlab_pusht_real": pusht_dataset_transform,
|
||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
|
||||
"nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
|
||||
"stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
|
||||
"austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
|
||||
"nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
|
||||
"maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
|
||||
"furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
|
||||
"cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
|
||||
"ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
|
||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
|
||||
"austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
|
||||
"austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
|
||||
"bc_z": bc_z_dataset_transform,
|
||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
|
||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
|
||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": identity_transform,
|
||||
"utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
|
||||
"robo_net": robo_net_dataset_transform,
|
||||
"berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
|
||||
"berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
|
||||
"kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
|
||||
"stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
|
||||
"tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
|
||||
"dlr_sara_pour_converted_externally_to_rlds": identity_transform,
|
||||
"dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
|
||||
"dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
|
||||
"asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
|
||||
"stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
|
||||
"imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
|
||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
|
||||
"uiuc_d3field": uiuc_d3field_dataset_transform,
|
||||
"utaustin_mutex": utaustin_mutex_dataset_transform,
|
||||
"berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
|
||||
"cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
|
||||
"cmu_play_fusion": playfusion_dataset_transform,
|
||||
"cmu_stretch": cmu_stretch_dataset_transform,
|
||||
"berkeley_gnm_recon": gnm_dataset_transform,
|
||||
"berkeley_gnm_cory_hall": gnm_dataset_transform,
|
||||
"berkeley_gnm_sac_son": gnm_dataset_transform,
|
||||
"droid": droid_baseact_transform_fn(),
|
||||
"droid_100": droid_baseact_transform_fn(), # first 100 episodes of droid
|
||||
"fmb": fmb_transform,
|
||||
"dobbe": dobbe_dataset_transform,
|
||||
"robo_set": robo_set_dataset_transform,
|
||||
"usc_cloth_sim_converted_externally_to_rlds": identity_transform,
|
||||
"plex_robosuite": identity_transform,
|
||||
"conq_hose_manipulation": identity_transform,
|
||||
"io_ai_tech": identity_transform,
|
||||
"spoc": identity_transform,
|
||||
}
|
|
@ -0,0 +1,359 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
||||
|
||||
Example:
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
|
||||
--repo-id youliangtan/sampled_bridge_data_v2 \
|
||||
--raw-format openx_rlds.bridge_orig \
|
||||
--episodes 3 4 5 8 9
|
||||
|
||||
Exact dataset fps defined in openx/config.py, obtained from:
|
||||
https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import torch
|
||||
import tqdm
|
||||
import yaml
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml", "r") as f:
|
||||
_openx_list = yaml.safe_load(f)
|
||||
|
||||
OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"]
|
||||
|
||||
np.set_printoptions(precision=2)
|
||||
|
||||
|
||||
def tf_to_torch(data):
|
||||
return torch.from_numpy(data.numpy())
|
||||
|
||||
|
||||
def tf_img_convert(img):
|
||||
if img.dtype == tf.string:
|
||||
img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
|
||||
elif img.dtype != tf.uint8:
|
||||
raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
|
||||
return img.numpy()
|
||||
|
||||
|
||||
def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
|
||||
"""
|
||||
In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
|
||||
entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
|
||||
trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
|
||||
|
||||
NOTE: adapted from DLimp library https://github.com/kvablack/dlimp/
|
||||
"""
|
||||
steps = traj.pop("steps")
|
||||
|
||||
traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]
|
||||
|
||||
# broadcast metadata to the length of the trajectory
|
||||
metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)
|
||||
|
||||
# put steps back in
|
||||
assert "traj_metadata" not in steps
|
||||
traj = {**steps, "traj_metadata": metadata}
|
||||
|
||||
assert "_len" not in traj
|
||||
assert "_traj_index" not in traj
|
||||
assert "_frame_index" not in traj
|
||||
traj["_len"] = tf.repeat(traj_len, traj_len)
|
||||
traj["_traj_index"] = tf.repeat(i, traj_len)
|
||||
traj["_frame_index"] = tf.range(traj_len)
|
||||
|
||||
return traj
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
openx_dataset_name: str | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
raw_dir (Path): _description_
|
||||
videos_dir (Path): _description_
|
||||
fps (int): _description_
|
||||
video (bool): _description_
|
||||
episodes (list[int] | None, optional): _description_. Defaults to None.
|
||||
"""
|
||||
ds_builder = tfds.builder_from_directory(str(raw_dir))
|
||||
dataset = ds_builder.as_dataset(
|
||||
split="all",
|
||||
decoders={"steps": tfds.decode.SkipDecoding()},
|
||||
)
|
||||
|
||||
dataset_info = ds_builder.info
|
||||
print("dataset_info: ", dataset_info)
|
||||
|
||||
ds_length = len(dataset)
|
||||
dataset = dataset.take(ds_length)
|
||||
# "flatten" the dataset as such we can apply trajectory level map() easily
|
||||
# each [obs][key] has a shape of (frame_size, ...)
|
||||
dataset = dataset.enumerate().map(_broadcast_metadata_rlds)
|
||||
|
||||
# we will apply the standardization transform if the dataset_name is provided
|
||||
# if the dataset name is not provided and the goal is to convert any rlds formatted dataset
|
||||
# search for 'image' keys in the observations
|
||||
if openx_dataset_name is not None:
|
||||
print(" - applying standardization transform for dataset: ", openx_dataset_name)
|
||||
assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS
|
||||
transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
|
||||
dataset = dataset.map(transform_fn)
|
||||
|
||||
image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"]
|
||||
else:
|
||||
obs_keys = dataset_info.features["steps"]["observation"].keys()
|
||||
image_keys = [key for key in obs_keys if "image" in key]
|
||||
|
||||
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
|
||||
|
||||
print(" - image_keys: ", image_keys)
|
||||
print(" - lang_key: ", lang_key)
|
||||
|
||||
it = iter(dataset)
|
||||
|
||||
ep_dicts = []
|
||||
# Init temp path to save ep_dicts in case of crash
|
||||
tmp_ep_dicts_dir = videos_dir.parent.joinpath("ep_dicts")
|
||||
tmp_ep_dicts_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# check if ep_dicts have already been saved in /tmp
|
||||
starting_ep_idx = 0
|
||||
saved_ep_dicts = [ep.__str__() for ep in tmp_ep_dicts_dir.iterdir()]
|
||||
if len(saved_ep_dicts) > 0:
|
||||
saved_ep_dicts.sort()
|
||||
# get last ep_idx number
|
||||
starting_ep_idx = int(saved_ep_dicts[-1][-13:-3]) + 1
|
||||
for i in range(starting_ep_idx):
|
||||
episode = next(it)
|
||||
ep_dicts.append(torch.load(saved_ep_dicts[i]))
|
||||
|
||||
# if we user specified episodes, skip the ones not in the list
|
||||
if episodes is not None:
|
||||
if ds_length == 0:
|
||||
raise ValueError("No episodes found.")
|
||||
# convert episodes index to sorted list
|
||||
episodes = sorted(episodes)
|
||||
|
||||
for ep_idx in tqdm.tqdm(range(starting_ep_idx, ds_length)):
|
||||
episode = next(it)
|
||||
|
||||
# if user specified episodes, skip the ones not in the list
|
||||
if episodes is not None:
|
||||
if len(episodes) == 0:
|
||||
break
|
||||
if ep_idx == episodes[0]:
|
||||
# process this episode
|
||||
print(" selecting episode idx: ", ep_idx)
|
||||
episodes.pop(0)
|
||||
else:
|
||||
continue # skip
|
||||
|
||||
num_frames = episode["action"].shape[0]
|
||||
|
||||
###########################################################
|
||||
# Handle the episodic data
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
ep_dict = {}
|
||||
langs = [] # TODO: might be located in "observation"
|
||||
|
||||
image_array_dict = {key: [] for key in image_keys}
|
||||
|
||||
# We will create the state observation tensor by stacking the state
|
||||
# obs keys defined in the openx/configs.py
|
||||
if openx_dataset_name is not None:
|
||||
state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"]
|
||||
# stack the state observations, if is None, pad with zeros
|
||||
states = []
|
||||
for key in state_obs_keys:
|
||||
if key in episode["observation"]:
|
||||
states.append(tf_to_torch(episode["observation"][key]))
|
||||
else:
|
||||
states.append(torch.zeros(num_frames, 1)) # pad with zeros
|
||||
states = torch.cat(states, dim=1)
|
||||
# assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
|
||||
else:
|
||||
states = tf_to_torch(episode["observation"]["state"])
|
||||
|
||||
actions = tf_to_torch(episode["action"])
|
||||
rewards = tf_to_torch(episode["reward"]).float()
|
||||
|
||||
# If lang_key is present, convert the entire tensor at once
|
||||
if lang_key is not None:
|
||||
langs = [str(x) for x in episode[lang_key]]
|
||||
|
||||
for im_key in image_keys:
|
||||
imgs = episode["observation"][im_key]
|
||||
image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
|
||||
|
||||
# simple assertions
|
||||
for item in [states, actions, rewards, done]:
|
||||
assert len(item) == num_frames
|
||||
|
||||
###########################################################
|
||||
|
||||
# loop through all cameras
|
||||
for im_key in image_keys:
|
||||
img_key = f"observation.images.{im_key}"
|
||||
imgs_array = image_array_dict[im_key]
|
||||
imgs_array = np.array(imgs_array)
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
if lang_key is not None:
|
||||
ep_dict["language_instruction"] = langs
|
||||
|
||||
ep_dict["observation.state"] = states
|
||||
ep_dict["action"] = actions
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["next.reward"] = rewards
|
||||
ep_dict["next.done"] = done
|
||||
|
||||
path_ep_dict = tmp_ep_dicts_dir.joinpath(
|
||||
"ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
|
||||
)
|
||||
torch.save(ep_dict, path_ep_dict)
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features = {}
|
||||
|
||||
keys = [key for key in data_dict if "observation.images." in key]
|
||||
for key in keys:
|
||||
if video:
|
||||
features[key] = VideoFrame()
|
||||
else:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "language_instruction" in data_dict:
|
||||
features["language_instruction"] = Value(dtype="string", id=None)
|
||||
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
features["next.reward"] = Value(dtype="float32", id=None)
|
||||
features["next.done"] = Value(dtype="bool", id=None)
|
||||
features["index"] = Value(dtype="int64", id=None)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
openx_dataset_name: str | None = None,
|
||||
):
|
||||
"""This is a test impl for rlds conversion"""
|
||||
if openx_dataset_name is None:
|
||||
# set a default rlds frame rate if the dataset is not from openx
|
||||
fps = 30
|
||||
elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]:
|
||||
raise ValueError(
|
||||
"fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested"
|
||||
)
|
||||
fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"]
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
|
@ -80,6 +80,11 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
|||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif isinstance(first_item, str):
|
||||
# TODO (michel-aractingi): add str2embedding via language tokenizer
|
||||
# For now we leave this part up to the user to choose how to address
|
||||
# language conditioned tasks
|
||||
pass
|
||||
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
|
||||
# video frame will be processed downstream
|
||||
pass
|
||||
|
|
|
@ -66,6 +66,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
|||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif "openx_rlds" in raw_format:
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "dora_parquet":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "xarm_pkl":
|
||||
|
@ -197,9 +199,25 @@ def push_dataset_to_hub(
|
|||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
raw_dir, videos_dir, fps, video, episodes, encoding
|
||||
)
|
||||
|
||||
fmt_kwgs = {
|
||||
"raw_dir": raw_dir,
|
||||
"videos_dir": videos_dir,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
"episodes": episodes,
|
||||
"encoding": encoding,
|
||||
}
|
||||
|
||||
if "openx_rlds." in raw_format:
|
||||
# Support for official OXE dataset name inside `raw_format`.
|
||||
# For instance, `raw_format="oxe_rlds"` uses the default formating (TODO what does that mean?),
|
||||
# and `raw_format="oxe_rlds.bridge_orig"` uses the brdige_orig formating
|
||||
_, openx_dataset_name = raw_format.split(".")
|
||||
print(f"Converting dataset [{openx_dataset_name}] from 'openx_rlds' to LeRobot format.")
|
||||
fmt_kwgs["openx_dataset_name"] = openx_dataset_name
|
||||
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(**fmt_kwgs)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
|
@ -268,7 +286,7 @@ def main():
|
|||
"--raw-format",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `openx_rlds`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
|
@ -328,6 +346,13 @@ def main():
|
|||
default=0,
|
||||
help="When set to 1, resumes a previous run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-dir",
|
||||
type=Path,
|
||||
required=False,
|
||||
default="/tmp",
|
||||
help="Directory to store the temporary videos and images generated while creating the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
|
|
|
@ -84,5 +84,7 @@ if __name__ == "__main__":
|
|||
"lerobot/pusht",
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"lerobot/xarm_lift_medium",
|
||||
"lerobot/nyu_franka_play_dataset",
|
||||
"lerobot/cmu_stretch",
|
||||
]:
|
||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
|
||||
|
|
|
@ -303,6 +303,9 @@ def test_flatten_unflatten_dict():
|
|||
"lerobot/pusht",
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"lerobot/xarm_lift_medium",
|
||||
# (michel-aractingi) commenting the two datasets from openx as test is failing
|
||||
# "lerobot/nyu_franka_play_dataset",
|
||||
# "lerobot/cmu_stretch",
|
||||
],
|
||||
)
|
||||
def test_backward_compatibility(repo_id):
|
||||
|
@ -318,6 +321,11 @@ def test_backward_compatibility(repo_id):
|
|||
new_frame = dataset[i] # noqa: B023
|
||||
old_frame = load_file(test_dir / f"frame_{i}.safetensors") # noqa: B023
|
||||
|
||||
# ignore language instructions (if exists) in language conditioned datasets
|
||||
# TODO (michel-aractingi): transform language obs to langauge embeddings via tokenizer
|
||||
new_frame.pop("language_instruction", None)
|
||||
old_frame.pop("language_instruction", None)
|
||||
|
||||
new_keys = set(new_frame.keys())
|
||||
old_keys = set(old_frame.keys())
|
||||
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
|
||||
|
|
Loading…
Reference in New Issue