From d8e4a2ccd7d76cf7b27753db2dd4f5fa4b75a6cb Mon Sep 17 00:00:00 2001 From: Claudio Coppola Date: Mon, 27 Jan 2025 17:10:38 +0000 Subject: [PATCH] fix: aloha_hd5 to LerobotDataset v2 frame appending out of the right scope --- lerobot/scripts/push_aloha_dataset_to_hub.py | 82 ++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 lerobot/scripts/push_aloha_dataset_to_hub.py diff --git a/lerobot/scripts/push_aloha_dataset_to_hub.py b/lerobot/scripts/push_aloha_dataset_to_hub.py new file mode 100644 index 00000000..c40256d3 --- /dev/null +++ b/lerobot/scripts/push_aloha_dataset_to_hub.py @@ -0,0 +1,82 @@ +from pathlib import Path + +import cv2 +import h5py +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + +data_path = Path("/home/ccop/code/aloha_data") + + +def get_features(hdf5_file): + topics = [] + features = {} + hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None) + for topic in topics: + # print(topic.replace('/', '.')) + if "images" in topic.split("/"): + features[topic.replace("/", ".")] = { + "dtype": "image", + "shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape, + "names": None, + } + elif "compress_len" in topic.split("/"): + continue + else: + features[topic.replace("/", ".")] = { + "dtype": str(hdf5_file[topic][0].dtype), + "shape": hdf5_file[topic][0].shape, + "names": None, + } + + return features + + +def extract_episode(episode_path, features, n_frames, dataset): + with h5py.File(episode_path, "r") as file: + # List all groups + for frame_idx in range(n_frames): + frame = {} + for feature in features: + if "images" in feature.split("."): + frame[feature] = torch.from_numpy( + cv2.imdecode(file[feature.replace(".", "/")][frame_idx], 1).transpose(2, 0, 1) + ) + else: + frame[feature] = torch.from_numpy(file[feature.replace(".", "/")][frame_idx]) + + dataset.add_frame(frame) + + +def get_dataset_properties(raw_folder): + from os import listdir + + episode_list = listdir(raw_folder) + with h5py.File(raw_folder / episode_list[0], "r") as file: + features = get_features(file) + n_frames = file["observations/images/cam_high"][:].shape[0] + return features, n_frames + + +if __name__ == "__main__": + raw_folder = data_path.absolute() / "aloha_stationary_replay_test" + episode_file = "episode_0.hdf5" + + features, n_frames = get_dataset_properties(raw_folder) + + dataset = LeRobotDataset.create( + repo_id="ccop/aloha_stationary_replay_test_v3", + fps=50, + robot_type="aloha-stationary", + features=features, + image_writer_threads=4, + ) + + extract_episode(raw_folder / episode_file, features, n_frames, dataset) + print("save episode!") + dataset.save_episode( + task="move_cube", + ) + dataset.consolidate() + dataset.push_to_hub()