diff --git a/examples/port_datasets/aloha_hd5.py b/examples/port_datasets/aloha_hd5.py index 1e4054a6..2a6bb590 100644 --- a/examples/port_datasets/aloha_hd5.py +++ b/examples/port_datasets/aloha_hd5.py @@ -118,7 +118,11 @@ class AlohaHD5Extractor: for frame_idx in range(file["/action"].shape[0]): frame = {} for feature_id in features: - feature_name_hd5 = feature_id.replace(".", "/") + feature_name_hd5 = ( + feature_id.replace(".", "/") + .replace("observation", "observations") + .replace("state", "qpos") + ) if "images" in feature_id.split("."): image = ( (file[feature_name_hd5][frame_idx]) @@ -153,8 +157,8 @@ class AlohaHD5Extractor: """ # Initialize lists to store topics and features - topics = [] - features = {} + topics: list[str] = [] + features: dict[str, dict] = {} # Open the HDF5 file with h5py.File(hdf5_file_path, "r") as hdf5_file: @@ -166,9 +170,12 @@ class AlohaHD5Extractor: # Iterate over each topic to define its features for topic in topics: # If the topic is an image, define it as a video feature + destination_topic = ( + topic.replace("/", ".").replace("observations", "observation").replace("qpos", "state") + ) if "images" in topic.split("/"): sample = hdf5_file[topic][0] - features[topic.replace("/", ".")] = { + features[destination_topic] = { "dtype": "video" if encode_as_video else "image", "shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape if image_compressed @@ -184,7 +191,7 @@ class AlohaHD5Extractor: continue # Otherwise, define it as a regular feature else: - features[topic.replace("/", ".")] = { + features[destination_topic] = { "dtype": str(hdf5_file[topic][0].dtype), "shape": (topic_shape := hdf5_file[topic][0].shape), "names": [f"{topic.split('/')[-1]}_{k}" for k in range(topic_shape[0])],