fix: aloha_hd5 to LerobotDataset v2 frame appending out of the right scope
This commit is contained in:
parent
283545f3f1
commit
d8e4a2ccd7
|
@ -0,0 +1,82 @@
|
|||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
data_path = Path("/home/ccop/code/aloha_data")
|
||||
|
||||
|
||||
def get_features(hdf5_file):
|
||||
topics = []
|
||||
features = {}
|
||||
hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None)
|
||||
for topic in topics:
|
||||
# print(topic.replace('/', '.'))
|
||||
if "images" in topic.split("/"):
|
||||
features[topic.replace("/", ".")] = {
|
||||
"dtype": "image",
|
||||
"shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape,
|
||||
"names": None,
|
||||
}
|
||||
elif "compress_len" in topic.split("/"):
|
||||
continue
|
||||
else:
|
||||
features[topic.replace("/", ".")] = {
|
||||
"dtype": str(hdf5_file[topic][0].dtype),
|
||||
"shape": hdf5_file[topic][0].shape,
|
||||
"names": None,
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def extract_episode(episode_path, features, n_frames, dataset):
|
||||
with h5py.File(episode_path, "r") as file:
|
||||
# List all groups
|
||||
for frame_idx in range(n_frames):
|
||||
frame = {}
|
||||
for feature in features:
|
||||
if "images" in feature.split("."):
|
||||
frame[feature] = torch.from_numpy(
|
||||
cv2.imdecode(file[feature.replace(".", "/")][frame_idx], 1).transpose(2, 0, 1)
|
||||
)
|
||||
else:
|
||||
frame[feature] = torch.from_numpy(file[feature.replace(".", "/")][frame_idx])
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
def get_dataset_properties(raw_folder):
|
||||
from os import listdir
|
||||
|
||||
episode_list = listdir(raw_folder)
|
||||
with h5py.File(raw_folder / episode_list[0], "r") as file:
|
||||
features = get_features(file)
|
||||
n_frames = file["observations/images/cam_high"][:].shape[0]
|
||||
return features, n_frames
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raw_folder = data_path.absolute() / "aloha_stationary_replay_test"
|
||||
episode_file = "episode_0.hdf5"
|
||||
|
||||
features, n_frames = get_dataset_properties(raw_folder)
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="ccop/aloha_stationary_replay_test_v3",
|
||||
fps=50,
|
||||
robot_type="aloha-stationary",
|
||||
features=features,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
extract_episode(raw_folder / episode_file, features, n_frames, dataset)
|
||||
print("save episode!")
|
||||
dataset.save_episode(
|
||||
task="move_cube",
|
||||
)
|
||||
dataset.consolidate()
|
||||
dataset.push_to_hub()
|
Loading…
Reference in New Issue