Nit fixes and test on the droid100 dataset
Signed-off-by: youliangtan <tan_you_liang@hotmail.com>
This commit is contained in:
parent
1837b4c1ff
commit
9f751093bc
|
@ -64,7 +64,7 @@ OXE_DATASET_CONFIGS = {
|
||||||
"state_obs_keys": ["base_pose_tool_reached", "gripper_closed"],
|
"state_obs_keys": ["base_pose_tool_reached", "gripper_closed"],
|
||||||
"state_encoding": StateEncoding.POS_QUAT,
|
"state_encoding": StateEncoding.POS_QUAT,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
"fps": 3, # TODO: placeholder (youliang)
|
"fps": 3,
|
||||||
},
|
},
|
||||||
"kuka": {
|
"kuka": {
|
||||||
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
|
||||||
|
@ -82,6 +82,7 @@ OXE_DATASET_CONFIGS = {
|
||||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||||
"state_encoding": StateEncoding.POS_EULER,
|
"state_encoding": StateEncoding.POS_EULER,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 5,
|
||||||
},
|
},
|
||||||
"bridge_orig": { # Original version of Bridge V2 from project website
|
"bridge_orig": { # Original version of Bridge V2 from project website
|
||||||
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
|
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
|
||||||
|
@ -552,6 +553,19 @@ OXE_DATASET_CONFIGS = {
|
||||||
"state_obs_keys": ["proprio"],
|
"state_obs_keys": ["proprio"],
|
||||||
"state_encoding": StateEncoding.POS_QUAT,
|
"state_encoding": StateEncoding.POS_QUAT,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 15,
|
||||||
|
},
|
||||||
|
"droid100": { # For testing
|
||||||
|
"image_obs_keys": {
|
||||||
|
"primary": "exterior_image_1_left",
|
||||||
|
"secondary": "exterior_image_2_left",
|
||||||
|
"wrist": "wrist_image_left",
|
||||||
|
},
|
||||||
|
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||||
|
"state_obs_keys": ["proprio"],
|
||||||
|
"state_encoding": StateEncoding.POS_QUAT,
|
||||||
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 15,
|
||||||
},
|
},
|
||||||
"fmb_dataset": {
|
"fmb_dataset": {
|
||||||
"image_obs_keys": {
|
"image_obs_keys": {
|
||||||
|
@ -604,6 +618,7 @@ OXE_DATASET_CONFIGS = {
|
||||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||||
"state_encoding": StateEncoding.POS_EULER,
|
"state_encoding": StateEncoding.POS_EULER,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 15,
|
||||||
},
|
},
|
||||||
"tdroid_pour_corn_in_pot": { # "pour corn from red bowl into steel pot" task, 50 demos @ 5 Hz control
|
"tdroid_pour_corn_in_pot": { # "pour corn from red bowl into steel pot" task, 50 demos @ 5 Hz control
|
||||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||||
|
@ -611,6 +626,7 @@ OXE_DATASET_CONFIGS = {
|
||||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||||
"state_encoding": StateEncoding.POS_EULER,
|
"state_encoding": StateEncoding.POS_EULER,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 15,
|
||||||
},
|
},
|
||||||
"tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control
|
"tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control
|
||||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||||
|
@ -618,6 +634,7 @@ OXE_DATASET_CONFIGS = {
|
||||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||||
"state_encoding": StateEncoding.POS_EULER,
|
"state_encoding": StateEncoding.POS_EULER,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 15,
|
||||||
},
|
},
|
||||||
"tdroid_move_object_onto_plate": { # "move <object> onto plate" task, 150 demos @ 5 Hz control
|
"tdroid_move_object_onto_plate": { # "move <object> onto plate" task, 150 demos @ 5 Hz control
|
||||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||||
|
@ -625,6 +642,7 @@ OXE_DATASET_CONFIGS = {
|
||||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||||
"state_encoding": StateEncoding.POS_EULER,
|
"state_encoding": StateEncoding.POS_EULER,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 15,
|
||||||
},
|
},
|
||||||
"tdroid_knock_object_over": { # "knock <object> over" task, 70 demos @ 5 Hz control
|
"tdroid_knock_object_over": { # "knock <object> over" task, 70 demos @ 5 Hz control
|
||||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||||
|
@ -632,6 +650,7 @@ OXE_DATASET_CONFIGS = {
|
||||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||||
"state_encoding": StateEncoding.POS_EULER,
|
"state_encoding": StateEncoding.POS_EULER,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 15,
|
||||||
},
|
},
|
||||||
"tdroid_cover_object_with_towel": { # "cover <object> with towel" task, 45 demos @ 5 Hz control
|
"tdroid_cover_object_with_towel": { # "cover <object> with towel" task, 45 demos @ 5 Hz control
|
||||||
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
|
||||||
|
@ -639,13 +658,19 @@ OXE_DATASET_CONFIGS = {
|
||||||
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
"state_obs_keys": ["EEF_state", None, "gripper_state"],
|
||||||
"state_encoding": StateEncoding.POS_EULER,
|
"state_encoding": StateEncoding.POS_EULER,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 15,
|
||||||
},
|
},
|
||||||
### DROID Finetuning datasets
|
### DROID Finetuning datasets
|
||||||
"droid_wipe": {
|
"droid_wipe": {
|
||||||
"image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"},
|
"image_obs_keys": {
|
||||||
|
"primary": "exterior_image_2_left",
|
||||||
|
"secondary": None,
|
||||||
|
"wrist": "wrist_image_left",
|
||||||
|
},
|
||||||
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
|
||||||
"state_obs_keys": ["proprio"],
|
"state_obs_keys": ["proprio"],
|
||||||
"state_encoding": StateEncoding.POS_EULER,
|
"state_encoding": StateEncoding.POS_EULER,
|
||||||
"action_encoding": ActionEncoding.EEF_POS,
|
"action_encoding": ActionEncoding.EEF_POS,
|
||||||
|
"fps": 15,
|
||||||
},
|
},
|
||||||
}
|
}
|
|
@ -85,7 +85,7 @@ def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
|
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
|
||||||
print("\n######################################################################################")
|
print("\n######################################################################################")
|
||||||
print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")
|
print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")
|
||||||
for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights):
|
for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights, strict=False):
|
||||||
pad = 80 - len(dataset_kwargs["name"])
|
pad = 80 - len(dataset_kwargs["name"])
|
||||||
print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #")
|
print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #")
|
||||||
print("######################################################################################\n")
|
print("######################################################################################\n")
|
||||||
|
|
|
@ -48,17 +48,17 @@ def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
|
||||||
Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
|
Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
|
r_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
|
||||||
R_frame_inv = invert_rmat(R_frame)
|
r_frame_inv = invert_rmat(r_frame)
|
||||||
|
|
||||||
# world to wrist: dT_pi = R^-1 dT_rbt
|
# world to wrist: dT_pi = R^-1 dT_rbt
|
||||||
vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0]
|
vel_t = (r_frame_inv @ velocity[:, :3][..., None])[..., 0]
|
||||||
|
|
||||||
# world to wrist: dR_pi = R^-1 dR_rbt R
|
# world to wrist: dR_pi = R^-1 dR_rbt R
|
||||||
dR = euler_to_rmat(velocity[:, 3:6])
|
dr_ = euler_to_rmat(velocity[:, 3:6])
|
||||||
dR = R_frame_inv @ (dR @ R_frame)
|
dr_ = r_frame_inv @ (dr_ @ r_frame)
|
||||||
dR_r6 = rotmat_to_rot6d(dR)
|
dr_r6 = rotmat_to_rot6d(dr_)
|
||||||
return tf.concat([vel_t, dR_r6], axis=-1)
|
return tf.concat([vel_t, dr_r6], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
def rand_swap_exterior_images(img1, img2):
|
def rand_swap_exterior_images(img1, img2):
|
||||||
|
@ -73,12 +73,12 @@ def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
DROID dataset transformation for actions expressed in *base* frame of the robot.
|
DROID dataset transformation for actions expressed in *base* frame of the robot.
|
||||||
"""
|
"""
|
||||||
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
|
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
|
||||||
dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
|
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
|
||||||
|
|
||||||
trajectory["action"] = tf.concat(
|
trajectory["action"] = tf.concat(
|
||||||
(
|
(
|
||||||
dt,
|
dt,
|
||||||
dR,
|
dr_,
|
||||||
1 - trajectory["action_dict"]["gripper_position"],
|
1 - trajectory["action_dict"]["gripper_position"],
|
||||||
),
|
),
|
||||||
axis=-1,
|
axis=-1,
|
||||||
|
@ -134,11 +134,11 @@ def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
DROID dataset transformation for actions expressed in *base* frame of the robot.
|
DROID dataset transformation for actions expressed in *base* frame of the robot.
|
||||||
"""
|
"""
|
||||||
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
|
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
|
||||||
dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
|
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
|
||||||
trajectory["action"] = tf.concat(
|
trajectory["action"] = tf.concat(
|
||||||
(
|
(
|
||||||
dt,
|
dt,
|
||||||
dR,
|
dr_,
|
||||||
1 - trajectory["action_dict"]["gripper_position"],
|
1 - trajectory["action_dict"]["gripper_position"],
|
||||||
),
|
),
|
||||||
axis=-1,
|
axis=-1,
|
||||||
|
@ -158,7 +158,7 @@ def zero_action_filter(traj: Dict) -> bool:
|
||||||
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
|
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
|
||||||
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
|
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
|
||||||
"""
|
"""
|
||||||
DROID_Q01 = tf.convert_to_tensor(
|
droid_q01 = tf.convert_to_tensor(
|
||||||
[
|
[
|
||||||
-0.7776297926902771,
|
-0.7776297926902771,
|
||||||
-0.5803514122962952,
|
-0.5803514122962952,
|
||||||
|
@ -168,7 +168,7 @@ def zero_action_filter(traj: Dict) -> bool:
|
||||||
-0.8895104378461838,
|
-0.8895104378461838,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
DROID_Q99 = tf.convert_to_tensor(
|
droid_q99 = tf.convert_to_tensor(
|
||||||
[
|
[
|
||||||
0.7597932070493698,
|
0.7597932070493698,
|
||||||
0.5726242214441299,
|
0.5726242214441299,
|
||||||
|
@ -178,6 +178,8 @@ def zero_action_filter(traj: Dict) -> bool:
|
||||||
0.8897542208433151,
|
0.8897542208433151,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
|
droid_norm_0_act = (
|
||||||
|
2 * (tf.zeros_like(traj["action"][:, :6]) - droid_q01) / (droid_q99 - droid_q01 + 1e-8) - 1
|
||||||
|
)
|
||||||
|
|
||||||
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
|
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - droid_norm_0_act) > 1e-5)
|
||||||
|
|
|
@ -30,13 +30,16 @@ from lerobot.common.datasets.push_dataset_to_hub.oxe.data_utils import (
|
||||||
relabel_bridge_actions,
|
relabel_bridge_actions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def droid_baseact_transform_fn():
|
def droid_baseact_transform_fn():
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.oxe.droid_utils import droid_baseact_transform
|
from lerobot.common.datasets.push_dataset_to_hub.oxe.droid_utils import droid_baseact_transform
|
||||||
|
|
||||||
return droid_baseact_transform
|
return droid_baseact_transform
|
||||||
|
|
||||||
|
|
||||||
def droid_finetuning_transform_fn():
|
def droid_finetuning_transform_fn():
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.oxe.droid_utils import droid_finetuning_transform
|
from lerobot.common.datasets.push_dataset_to_hub.oxe.droid_utils import droid_finetuning_transform
|
||||||
|
|
||||||
return droid_finetuning_transform
|
return droid_finetuning_transform
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +49,7 @@ def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
|
||||||
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
||||||
"""
|
"""
|
||||||
for key in trajectory.keys():
|
for key in trajectory:
|
||||||
if key == "traj_metadata":
|
if key == "traj_metadata":
|
||||||
continue
|
continue
|
||||||
elif key in ["observation", "action"]:
|
elif key in ["observation", "action"]:
|
||||||
|
@ -76,7 +79,7 @@ def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
|
||||||
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
||||||
"""
|
"""
|
||||||
for key in trajectory.keys():
|
for key in trajectory:
|
||||||
if key == "traj_metadata":
|
if key == "traj_metadata":
|
||||||
continue
|
continue
|
||||||
elif key == "observation":
|
elif key == "observation":
|
||||||
|
@ -148,7 +151,9 @@ def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
)
|
)
|
||||||
eef_value = tf.io.decode_raw(eef_value, tf.float32)
|
eef_value = tf.io.decode_raw(eef_value, tf.float32)
|
||||||
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
|
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
|
||||||
gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB")
|
gripper_value = tf.io.decode_compressed(
|
||||||
|
trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
|
||||||
|
)
|
||||||
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
|
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
|
||||||
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
|
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
|
||||||
# trajectory["language_instruction"] = tf.fill(
|
# trajectory["language_instruction"] = tf.fill(
|
||||||
|
@ -178,7 +183,9 @@ def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
|
||||||
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
|
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
|
||||||
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:]
|
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
|
||||||
|
:, -1:
|
||||||
|
]
|
||||||
|
|
||||||
# make gripper action absolute action, +1 = open, 0 = close
|
# make gripper action absolute action, +1 = open, 0 = close
|
||||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||||
|
@ -214,7 +221,9 @@ def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict
|
||||||
|
|
||||||
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
# invert absolute gripper action, +1 = open, 0 = close
|
# invert absolute gripper action, +1 = open, 0 = close
|
||||||
gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1))
|
gripper_action = invert_gripper_actions(
|
||||||
|
tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
trajectory["action"] = tf.concat(
|
trajectory["action"] = tf.concat(
|
||||||
(
|
(
|
||||||
|
@ -324,7 +333,9 @@ def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, An
|
||||||
instruction_bytes = trajectory["observation"]["instruction"]
|
instruction_bytes = trajectory["observation"]["instruction"]
|
||||||
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
|
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
|
||||||
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
|
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
|
||||||
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0]
|
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
|
||||||
|
:, 0
|
||||||
|
]
|
||||||
return trajectory
|
return trajectory
|
||||||
|
|
||||||
|
|
||||||
|
@ -895,7 +906,8 @@ OXE_STANDARDIZATION_TRANSFORMS = {
|
||||||
"berkeley_gnm_recon": gnm_dataset_transform,
|
"berkeley_gnm_recon": gnm_dataset_transform,
|
||||||
"berkeley_gnm_cory_hall": gnm_dataset_transform,
|
"berkeley_gnm_cory_hall": gnm_dataset_transform,
|
||||||
"berkeley_gnm_sac_son": gnm_dataset_transform,
|
"berkeley_gnm_sac_son": gnm_dataset_transform,
|
||||||
"droid": droid_baseact_transform_fn,
|
"droid": droid_baseact_transform_fn(),
|
||||||
|
"droid100": droid_baseact_transform_fn(), # first 100 episodes of droid
|
||||||
"fmb_dataset": fmb_dataset_transform,
|
"fmb_dataset": fmb_dataset_transform,
|
||||||
"dobbe": dobbe_dataset_transform,
|
"dobbe": dobbe_dataset_transform,
|
||||||
"roboset": roboset_dataset_transform,
|
"roboset": roboset_dataset_transform,
|
||||||
|
@ -908,5 +920,5 @@ OXE_STANDARDIZATION_TRANSFORMS = {
|
||||||
"tdroid_knock_object_over": tdroid_dataset_transform,
|
"tdroid_knock_object_over": tdroid_dataset_transform,
|
||||||
"tdroid_cover_object_with_towel": tdroid_dataset_transform,
|
"tdroid_cover_object_with_towel": tdroid_dataset_transform,
|
||||||
### DROID Finetuning datasets
|
### DROID Finetuning datasets
|
||||||
"droid_wipe": droid_finetuning_transform_fn,
|
"droid_wipe": droid_finetuning_transform_fn(),
|
||||||
}
|
}
|
|
@ -175,7 +175,7 @@ def load_from_raw(
|
||||||
else:
|
else:
|
||||||
states.append(torch.zeros(num_frames, 1)) # pad with zeros
|
states.append(torch.zeros(num_frames, 1)) # pad with zeros
|
||||||
states = torch.cat(states, dim=1)
|
states = torch.cat(states, dim=1)
|
||||||
assert states.shape == (num_frames, 8)
|
# assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
|
||||||
else:
|
else:
|
||||||
states = tf_to_torch(episode["observation"]["state"])
|
states = tf_to_torch(episode["observation"]["state"])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue