observations -> observation, qpos -> state

This commit is contained in:
Claudio Coppola 2024-12-20 09:09:24 +00:00
parent b65fd4428d
commit 72e0d24167
1 changed files with 12 additions and 5 deletions

View File

@ -118,7 +118,11 @@ class AlohaHD5Extractor:
for frame_idx in range(file["/action"].shape[0]): for frame_idx in range(file["/action"].shape[0]):
frame = {} frame = {}
for feature_id in features: for feature_id in features:
feature_name_hd5 = feature_id.replace(".", "/") feature_name_hd5 = (
feature_id.replace(".", "/")
.replace("observation", "observations")
.replace("state", "qpos")
)
if "images" in feature_id.split("."): if "images" in feature_id.split("."):
image = ( image = (
(file[feature_name_hd5][frame_idx]) (file[feature_name_hd5][frame_idx])
@ -153,8 +157,8 @@ class AlohaHD5Extractor:
""" """
# Initialize lists to store topics and features # Initialize lists to store topics and features
topics = [] topics: list[str] = []
features = {} features: dict[str, dict] = {}
# Open the HDF5 file # Open the HDF5 file
with h5py.File(hdf5_file_path, "r") as hdf5_file: with h5py.File(hdf5_file_path, "r") as hdf5_file:
@ -166,9 +170,12 @@ class AlohaHD5Extractor:
# Iterate over each topic to define its features # Iterate over each topic to define its features
for topic in topics: for topic in topics:
# If the topic is an image, define it as a video feature # If the topic is an image, define it as a video feature
destination_topic = (
topic.replace("/", ".").replace("observations", "observation").replace("qpos", "state")
)
if "images" in topic.split("/"): if "images" in topic.split("/"):
sample = hdf5_file[topic][0] sample = hdf5_file[topic][0]
features[topic.replace("/", ".")] = { features[destination_topic] = {
"dtype": "video" if encode_as_video else "image", "dtype": "video" if encode_as_video else "image",
"shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape "shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape
if image_compressed if image_compressed
@ -184,7 +191,7 @@ class AlohaHD5Extractor:
continue continue
# Otherwise, define it as a regular feature # Otherwise, define it as a regular feature
else: else:
features[topic.replace("/", ".")] = { features[destination_topic] = {
"dtype": str(hdf5_file[topic][0].dtype), "dtype": str(hdf5_file[topic][0].dtype),
"shape": (topic_shape := hdf5_file[topic][0].shape), "shape": (topic_shape := hdf5_file[topic][0].shape),
"names": [f"{topic.split('/')[-1]}_{k}" for k in range(topic_shape[0])], "names": [f"{topic.split('/')[-1]}_{k}" for k in range(topic_shape[0])],