fix: aloha_hdf5 to LeRobotDataset v2, frame appending was out of the right scope

This commit is contained in:
Claudio Coppola 2025-01-27 17:10:38 +00:00
parent 283545f3f1
commit d8e4a2ccd7
1 changed file with 82 additions and 0 deletions
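
The bug being fixed: dataset.add_frame(frame) had slipped outside the per-frame loop, so only the last decoded frame was appended to the episode. A minimal sketch of the change (the before-state is inferred from the commit message and is not shown in this diff; build_frame is a hypothetical stand-in for the decoding logic in extract_episode below):

# before (broken): frame is only added once, after the loop finishes
for frame_idx in range(n_frames):
    frame = build_frame(frame_idx)  # hypothetical helper
dataset.add_frame(frame)

# after (fixed): every decoded frame is appended
for frame_idx in range(n_frames):
    frame = build_frame(frame_idx)  # hypothetical helper
    dataset.add_frame(frame)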


@@ -0,0 +1,82 @@
from os import listdir
from pathlib import Path

import cv2
import h5py
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Root folder holding the raw ALOHA HDF5 recordings.
data_path = Path("/home/ccop/code/aloha_data")

def get_features(hdf5_file):
    """Build a LeRobotDataset feature spec from the datasets found in an ALOHA HDF5 file."""
    topics = []
    features = {}
    # Collect the full path of every dataset (leaf node) in the HDF5 hierarchy.
    hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None)
    for topic in topics:
        if "images" in topic.split("/"):
            # Camera streams are stored JPEG-compressed; decode one frame to get
            # the (C, H, W) shape expected by LeRobotDataset.
            features[topic.replace("/", ".")] = {
                "dtype": "image",
                "shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape,
                "names": None,
            }
        elif "compress_len" in topic.split("/"):
            # Bookkeeping for the JPEG compression; not a feature of the dataset.
            continue
        else:
            # Low-dimensional streams (qpos, action, ...): take dtype and shape
            # directly from the first frame.
            features[topic.replace("/", ".")] = {
                "dtype": str(hdf5_file[topic][0].dtype),
                "shape": hdf5_file[topic][0].shape,
                "names": None,
            }
    return features
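
# Example of the resulting spec (hypothetical values, for illustration only):
# {
#     "observations.images.cam_high": {"dtype": "image", "shape": (3, 480, 640), "names": None},
#     "observations.qpos": {"dtype": "float64", "shape": (14,), "names": None},
# }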

def extract_episode(episode_path, features, n_frames, dataset):
    """Decode every frame of one HDF5 episode and append it to the dataset."""
    with h5py.File(episode_path, "r") as file:
        for frame_idx in range(n_frames):
            frame = {}
            for feature in features:
                if "images" in feature.split("."):
                    # Decode the JPEG bytes and convert HWC -> CHW.
                    frame[feature] = torch.from_numpy(
                        cv2.imdecode(file[feature.replace(".", "/")][frame_idx], 1).transpose(2, 0, 1)
                    )
                else:
                    frame[feature] = torch.from_numpy(file[feature.replace(".", "/")][frame_idx])
            # Append inside the frame loop so every frame (not just the last one)
            # ends up in the episode; this is the scope fix of this commit.
            dataset.add_frame(frame)

def get_dataset_properties(raw_folder):
    """Infer the feature spec and episode length from the first episode in the folder."""
    # Sort so that "first episode" is deterministic across runs.
    episode_list = sorted(listdir(raw_folder))
    with h5py.File(raw_folder / episode_list[0], "r") as file:
        features = get_features(file)
        # Frame count from one camera stream's leading dimension; reading .shape
        # avoids loading every image into memory.
        n_frames = file["observations/images/cam_high"].shape[0]
    return features, n_frames

if __name__ == "__main__":
    raw_folder = data_path.absolute() / "aloha_stationary_replay_test"
    episode_file = "episode_0.hdf5"
    features, n_frames = get_dataset_properties(raw_folder)

    # Create an empty LeRobotDataset with the inferred feature spec; image frames
    # are written asynchronously by a small thread pool.
    dataset = LeRobotDataset.create(
        repo_id="ccop/aloha_stationary_replay_test_v3",
        fps=50,
        robot_type="aloha-stationary",
        features=features,
        image_writer_threads=4,
    )

    extract_episode(raw_folder / episode_file, features, n_frames, dataset)

    print("Saving episode!")
    dataset.save_episode(
        task="move_cube",
    )
    dataset.consolidate()
    dataset.push_to_hub()
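
Once pushed, the converted dataset can be loaded back for a quick sanity check. A minimal sketch, assuming the upload succeeded (num_episodes, num_frames, and integer indexing are standard LeRobotDataset accessors):

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("ccop/aloha_stationary_replay_test_v3")
print(dataset.num_episodes, dataset.num_frames)  # expect 1 episode of n_frames frames
sample = dataset[0]  # dict of tensors keyed by the dotted feature names built above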