fix: aloha_hd5 to LeRobotDataset v2 conversion, frame appending was in the wrong scope
parent 283545f3f1
commit d8e4a2ccd7

@@ -0,0 +1,82 @@
from os import listdir
from pathlib import Path

import cv2
import h5py
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

data_path = Path("/home/ccop/code/aloha_data")
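

# Build the LeRobotDataset feature spec by visiting every h5py.Dataset in the
# file. Compressed camera topics become "image" features (one frame is decoded
# to recover the CHW shape), "compress_len" bookkeeping arrays are skipped, and
# everything else keeps its native dtype and shape. HDF5 paths use "/" while
# feature keys use "." (e.g. "observations/images/cam_high" becomes
# "observations.images.cam_high").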
def get_features(hdf5_file):
    topics = []
    features = {}
    # Collect the full path of every dataset (leaf node) in the HDF5 file.
    hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None)
    for topic in topics:
        if "images" in topic.split("/"):
            features[topic.replace("/", ".")] = {
                "dtype": "image",
                "shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape,
                "names": None,
            }
        elif "compress_len" in topic.split("/"):
            continue
        else:
            features[topic.replace("/", ".")] = {
                "dtype": str(hdf5_file[topic][0].dtype),
                "shape": hdf5_file[topic][0].shape,
                "names": None,
            }

    return features
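

# Stream one episode into the dataset: decode/convert every feature for a
# frame, then call add_frame() once per frame. The call sits inside the frame
# loop but outside the per-feature loop, the scoping the commit message
# refers to.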
def extract_episode(episode_path, features, n_frames, dataset):
    with h5py.File(episode_path, "r") as file:
        for frame_idx in range(n_frames):
            frame = {}
            for feature in features:
                if "images" in feature.split("."):
                    # Camera topics are stored as compressed image buffers:
                    # decode to HWC, then transpose to CHW.
                    frame[feature] = torch.from_numpy(
                        cv2.imdecode(file[feature.replace(".", "/")][frame_idx], 1).transpose(2, 0, 1)
                    )
                else:
                    frame[feature] = torch.from_numpy(file[feature.replace(".", "/")][frame_idx])

            dataset.add_frame(frame)
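

# Inspect the first episode in the folder to recover the feature spec and the
# frame count; every episode is assumed to share the same layout.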
def get_dataset_properties(raw_folder):
    episode_list = listdir(raw_folder)
    with h5py.File(raw_folder / episode_list[0], "r") as file:
        features = get_features(file)
        # Read the frame count from the dataset shape; no need to load the data.
        n_frames = file["observations/images/cam_high"].shape[0]
    return features, n_frames
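

# Convert one recorded episode into a LeRobotDataset v2 and push it to the Hub.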
if __name__ == "__main__":
    raw_folder = data_path.absolute() / "aloha_stationary_replay_test"
    episode_file = "episode_0.hdf5"

    features, n_frames = get_dataset_properties(raw_folder)

    dataset = LeRobotDataset.create(
        repo_id="ccop/aloha_stationary_replay_test_v3",
        fps=50,
        robot_type="aloha-stationary",
        features=features,
        image_writer_threads=4,
    )

    extract_episode(raw_folder / episode_file, features, n_frames, dataset)
    print("save episode!")
    dataset.save_episode(
        task="move_cube",
    )
    dataset.consolidate()
    dataset.push_to_hub()