Claudio Coppola 2025-04-10 12:44:42 +02:00 committed by GitHub
commit b8d2cbdd62
2 changed files with 574 additions and 0 deletions


@@ -0,0 +1,492 @@
import argparse
import logging
import os
import shutil
import traceback
from pathlib import Path
import cv2
import h5py
import torch
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
class AlohaHD5Extractor:
TAGS = ["aloha", "robotics", "hdf5"]
aloha_stationary = "aloha-stationary"
aloha_mobile = "aloha-mobile"
@staticmethod
def get_cameras(hdf5_data: h5py.File):
"""
Extracts the list of RGB camera keys from the given HDF5 data.
Parameters
----------
hdf5_data : h5py.File
The HDF5 file object containing the dataset.
Returns
-------
list of str
A list of keys corresponding to RGB cameras in the dataset.
"""
rgb_cameras = [key for key in hdf5_data["/observations/images"] if "depth" not in key]
return rgb_cameras
@staticmethod
def check_format(episode_list: list[str] | list[Path], image_compressed: bool = True):
"""
Check the format of the given list of HDF5 files.
Parameters
----------
episode_list : list of str or list of Path
List of paths to the HDF5 files to be checked.
image_compressed : bool, optional
Flag indicating whether the images are compressed (default is True).
Raises
------
ValueError
If the episode_list is empty.
If any HDF5 file is missing required keys '/action' or '/observations/qpos'.
If the '/action' or '/observations/qpos' keys do not have 2 dimensions.
If the number of frames in '/action' and '/observations/qpos' keys do not match.
If the number of frames in '/observations/images/{camera}' does not match the number of frames in '/action' and '/observations/qpos'.
If the dimensions of images do not match the expected dimensions based on the image_compressed flag.
If uncompressed images do not have the expected (h, w, c) format.
"""
if not episode_list:
raise ValueError(
"No hdf5 files found in the raw directory. Make sure they are named 'episode_*.hdf5'"
)
for episode_path in episode_list:
with h5py.File(episode_path, "r") as data:
if not all(key in data for key in ["/action", "/observations/qpos"]):
raise ValueError(
"Missing required keys in the hdf5 file. Make sure the keys '/action' and '/observations/qpos' are present."
)
if not data["/action"].ndim == data["/observations/qpos"].ndim == 2:
raise ValueError(
"The '/action' and '/observations/qpos' keys should have both 2 dimensions."
)
if (num_frames := data["/action"].shape[0]) != data["/observations/qpos"].shape[0]:
raise ValueError(
"The '/action' and '/observations/qpos' keys should have the same number of frames."
)
for camera in AlohaHD5Extractor.get_cameras(data):
if num_frames != data[f"/observations/images/{camera}"].shape[0]:
raise ValueError(
f"The number of frames in '/observations/images/{camera}' should be the same as in '/action' and '/observations/qpos' keys."
)
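                    # Compressed images are stored as one encoded byte buffer per frame,
                    # i.e. (num_frames, buffer_len) -> 2 dims; raw images are (num_frames, h, w, c) -> 4 dims.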
expected_dims = 2 if image_compressed else 4
if data[f"/observations/images/{camera}"].ndim != expected_dims:
raise ValueError(
f"Expect {expected_dims} dimensions for {'compressed' if image_compressed else 'uncompressed'} images but {data[f'/observations/images/{camera}'].ndim} provided."
)
if not image_compressed:
b, h, w, c = data[f"/observations/images/{camera}"].shape
                        if not (c < h and c < w):
                            raise ValueError(f"Expected (h, w, c) image format but ({h=}, {w=}, {c=}) provided.")
@staticmethod
def extract_episode_frames(
episode_path: str | Path, features: dict[str, dict], image_compressed: bool
) -> list[dict[str, torch.Tensor]]:
"""
Extract frames from an episode stored in an HDF5 file.
Parameters
----------
episode_path : str or Path
Path to the HDF5 file containing the episode data.
features : dict of str to dict
Dictionary where keys are feature identifiers and values are dictionaries with feature details.
image_compressed : bool
Flag indicating whether the images are stored in a compressed format.
Returns
-------
list of dict of str to torch.Tensor
List of frames, where each frame is a dictionary mapping feature identifiers to tensors.
"""
frames = []
with h5py.File(episode_path, "r") as file:
for frame_idx in range(file["/action"].shape[0]):
frame = {}
for feature_id in features:
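                    # Map the LeRobot feature id back to its HDF5 dataset path, e.g.
                    # "observation.state" -> "observations/qpos" and
                    # "observation.images.cam_high" -> "observations/images/cam_high".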
feature_name_hd5 = (
feature_id.replace(".", "/")
.replace("observation", "observations")
.replace("state", "qpos")
)
if "images" in feature_id.split("."):
image = (
(file[feature_name_hd5][frame_idx])
if not image_compressed
else cv2.imdecode(file[feature_name_hd5][frame_idx], 1)
)
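                        # HDF5 stores images as HWC; transpose to the CHW layout used by LeRobot.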
frame[feature_id] = torch.from_numpy(image.transpose(2, 0, 1))
else:
frame[feature_id] = torch.from_numpy(file[feature_name_hd5][frame_idx])
frames.append(frame)
return frames
@staticmethod
def define_features(
hdf5_file_path: Path, image_compressed: bool = True, encode_as_video: bool = True
) -> dict[str, dict]:
"""
Define features from an HDF5 file.
Parameters
----------
hdf5_file_path : Path
The path to the HDF5 file.
image_compressed : bool, optional
Whether the images are compressed, by default True.
encode_as_video : bool, optional
Whether to encode images as video or as images, by default True.
Returns
-------
dict[str, dict]
A dictionary where keys are topic names and values are dictionaries
containing feature information such as dtype, shape, and names.
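
        Examples
        --------
        Illustrative result for a dataset with one camera ("cam_high") and a
        14-dimensional state/action; the camera name, dtypes and sizes below are
        placeholders, not values read from a real file::

            {
                "observation.images.cam_high": {
                    "dtype": "video", "shape": (3, 480, 640),
                    "names": ["channel", "height", "width"],
                },
                "observation.state": {
                    "dtype": "float64", "shape": (14,),
                    "names": ["qpos_0", ..., "qpos_13"],
                },
                "action": {
                    "dtype": "float64", "shape": (14,),
                    "names": ["action_0", ..., "action_13"],
                },
            }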
"""
# Initialize lists to store topics and features
topics: list[str] = []
features: dict[str, dict] = {}
# Open the HDF5 file
with h5py.File(hdf5_file_path, "r") as hdf5_file:
# Collect all dataset names in the HDF5 file
hdf5_file.visititems(
lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None
)
# Iterate over each topic to define its features
for topic in topics:
                # Rename the HDF5 path to its LeRobot feature id, e.g. "observations/qpos" -> "observation.state"
                destination_topic = (
                    topic.replace("/", ".").replace("observations", "observation").replace("qpos", "state")
                )
                # If the topic is an image, define it as a video (or image) feature
if "images" in topic.split("/"):
sample = hdf5_file[topic][0]
features[destination_topic] = {
"dtype": "video" if encode_as_video else "image",
"shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape
if image_compressed
else sample.shape,
"names": [
"channel",
"height",
"width",
],
}
# Skip compressed length topics
elif "compress_len" in topic.split("/"):
continue
# Otherwise, define it as a regular feature
else:
features[destination_topic] = {
"dtype": str(hdf5_file[topic][0].dtype),
"shape": (topic_shape := hdf5_file[topic][0].shape),
"names": [f"{topic.split('/')[-1]}_{k}" for k in range(topic_shape[0])],
}
# Return the defined features
return features
class DatasetConverter:
"""
A class to convert datasets to Lerobot format.
Parameters
----------
raw_path : Path or str
The path to the raw dataset.
dataset_repo_id : str
The repository ID where the dataset will be stored.
fps : int
Frames per second for the dataset.
robot_type : str, optional
The type of robot, by default "".
encode_as_videos : bool, optional
Whether to encode images as videos, by default True.
image_compressed : bool, optional
Whether the images are compressed, by default True.
image_writer_processes : int, optional
Number of processes for writing images, by default 0.
image_writer_threads : int, optional
Number of threads for writing images, by default 0.
Methods
-------
extract_episode(episode_path, task_description='')
Extracts frames from a single episode and saves it with a description.
extract_episodes(episode_description='')
Extracts frames from all episodes and saves them with a description.
push_dataset_to_hub(dataset_tags=None, private=False, push_videos=True, license="apache-2.0")
Pushes the dataset to the Hugging Face Hub.
init_lerobot_dataset()
Initializes the Lerobot dataset.
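
    Examples
    --------
    Illustrative usage; the raw path, repo ID and task description below are placeholders:

    >>> converter = DatasetConverter(
    ...     raw_path="/data/aloha_raw",
    ...     dataset_repo_id="user/aloha-demo",
    ...     fps=50,
    ...     robot_type="aloha-stationary",
    ... )
    >>> converter.init_lerobot_dataset()
    >>> converter.extract_episodes(episode_description="Pick up the cube and place it in the box.")
    >>> converter.push_dataset_to_hub(dataset_tags=AlohaHD5Extractor.TAGS)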
"""
def __init__(
self,
raw_path: Path | str,
dataset_repo_id: str,
fps: int,
robot_type: str = "",
encode_as_videos: bool = True,
image_compressed: bool = True,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
):
self.raw_path = raw_path if isinstance(raw_path, Path) else Path(raw_path)
self.dataset_repo_id = dataset_repo_id
self.fps = fps
self.robot_type = robot_type
self.image_compressed = image_compressed
self.image_writer_threads = image_writer_threads
self.image_writer_processes = image_writer_processes
self.encode_as_videos = encode_as_videos
self.logger = logging.getLogger(self.__class__.__name__)
self.logger.setLevel(logging.INFO)
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - [%(name)s] - %(message)s")
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
self.logger.info(f"{'-' * 10} Aloha HD5 -> Lerobot Converter {'-' * 10}")
self.logger.info(f"Processing Aloha HD5 dataset from {self.raw_path}")
self.logger.info(f"Dataset will be stored in {self.dataset_repo_id}")
self.logger.info(f"FPS: {self.fps}")
self.logger.info(f"Robot type: {self.robot_type}")
self.logger.info(f"Image compressed: {self.image_compressed}")
self.logger.info(f"Encoding images as videos: {self.encode_as_videos}")
self.logger.info(f"#writer processes: {self.image_writer_processes}")
self.logger.info(f"#writer threads: {self.image_writer_threads}")
self.episode_list = list(self.raw_path.glob("episode_*.hdf5"))
AlohaHD5Extractor.check_format(self.episode_list, image_compressed=self.image_compressed)
self.features = AlohaHD5Extractor.define_features(
self.episode_list[0],
image_compressed=self.image_compressed,
encode_as_video=self.encode_as_videos,
)
def extract_episode(self, episode_path, task_description: str = ""):
"""
Extracts frames from an episode and saves them to the dataset.
Parameters
----------
episode_path : str
The path to the episode file.
task_description : str, optional
A description of the task associated with the episode (default is an empty string).
Returns
-------
None
"""
for frame in AlohaHD5Extractor.extract_episode_frames(
episode_path, self.features, self.image_compressed
):
self.dataset.add_frame(frame)
self.logger.info(f"Saving Episode with Description: {task_description} ...")
self.dataset.save_episode(task=task_description)
def extract_episodes(self, episode_description: str = ""):
"""
Extracts episodes from the episode list and processes them.
Parameters
----------
episode_description : str, optional
A description of the task to be passed to the extract_episode method (default is '').
Raises
------
Exception
If an error occurs during the processing of an episode, it will be caught and printed.
Notes
-----
After processing all episodes, the dataset is consolidated.
"""
for episode_path in self.episode_list:
try:
self.extract_episode(episode_path, task_description=episode_description)
            except Exception as e:
                self.logger.error(f"Error processing episode {episode_path}: {e}")
                traceback.print_exc()
continue
self.dataset.consolidate()
def push_dataset_to_hub(
self,
dataset_tags: list[str] | None = None,
private: bool = False,
push_videos: bool = True,
license: str | None = "apache-2.0",
):
"""
Pushes the dataset to the Hugging Face Hub.
Parameters
----------
dataset_tags : list of str, optional
A list of tags to associate with the dataset on the Hub. Default is None.
private : bool, optional
If True, the dataset will be private. Default is False.
push_videos : bool, optional
If True, videos will be pushed along with the dataset. Default is True.
license : str, optional
The license under which the dataset is released. Default is "apache-2.0".
Returns
-------
None
"""
self.logger.info(f"Pushing dataset to Hugging Face Hub. ID: {self.dataset_repo_id} ...")
self.dataset.push_to_hub(
tags=dataset_tags,
license=license,
push_videos=push_videos,
private=private,
)
def init_lerobot_dataset(self):
"""
Initializes the LeRobot dataset.
This method cleans the cache if the dataset already exists and then creates a new LeRobot dataset.
Returns
-------
LeRobotDataset
The initialized LeRobot dataset.
"""
# Clean the cache if the dataset already exists
if os.path.exists(LEROBOT_HOME / self.dataset_repo_id):
shutil.rmtree(LEROBOT_HOME / self.dataset_repo_id)
self.dataset = LeRobotDataset.create(
repo_id=self.dataset_repo_id,
fps=self.fps,
robot_type=self.robot_type,
features=self.features,
image_writer_threads=self.image_writer_threads,
image_writer_processes=self.image_writer_processes,
)
return self.dataset
def str2bool(value):
if isinstance(value, bool):
return value
value = value.lower()
if value in ("yes", "true", "t", "y", "1"):
return True
elif value in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
def main():
"""
Convert Aloha HD5 dataset and push to Hugging Face hub.
This script processes raw HDF5 files from the Aloha dataset, converts them into a specified format,
and optionally uploads the dataset to the Hugging Face hub.
Parameters
----------
--raw-path : Path
Directory containing the raw HDF5 files.
--dataset-repo-id : str
Repository ID where the dataset will be stored.
--fps : int
Frames per second for the dataset.
    --description : str, optional
        Task description applied to every episode. Default is "Aloha recorded dataset.".
    --robot-type : str, optional
        Type of robot, either "aloha-stationary" or "aloha-mobile". Default is "aloha-stationary".
    --private : bool, optional
        Set to True to make the dataset private. Default is False.
    --push : bool, optional
        Set to True to push the dataset to the Hugging Face hub. Default is True.
--license : str, optional
License for the dataset. Default is "apache-2.0".
--image-compressed : bool, optional
Set to True if the images are compressed. Default is True.
--video-encoding : bool, optional
Set to True to encode images as videos. Default is True.
--nproc : int, optional
Number of image writer processes. Default is 10.
--nthreads : int, optional
Number of image writer threads. Default is 5.
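
    Examples
    --------
    A typical invocation; the script name, path and repo ID are placeholders::

        python aloha_hd5_to_lerobot.py \
            --raw-path /data/aloha_raw \
            --dataset-repo-id user/aloha-demo \
            --fps 50 \
            --robot-type aloha-stationary \
            --description "Pick up the cube and place it in the box."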
"""
parser = argparse.ArgumentParser(description="Convert Aloha HD5 dataset and push to Hugging Face hub.")
parser.add_argument(
"--raw-path", type=Path, required=True, help="Directory containing the raw hdf5 files."
)
parser.add_argument(
"--dataset-repo-id", type=str, required=True, help="Repository ID where the dataset will be stored."
)
parser.add_argument("--fps", type=int, required=True, help="Frames per second for the dataset.")
parser.add_argument(
"--description", type=str, help="Description of the dataset.", default="Aloha recorded dataset."
)
parser.add_argument(
"--robot-type",
type=str,
choices=["aloha-stationary", "aloha-mobile"],
default="aloha-stationary",
help="Type of robot.",
)
parser.add_argument(
"--private", type=str2bool, default=False, help="Set to True to make the dataset private."
)
parser.add_argument("--push", type=str2bool, default=True, help="Set to True to push videos to the hub.")
parser.add_argument("--license", type=str, default="apache-2.0", help="License for the dataset.")
parser.add_argument(
"--image-compressed", type=str2bool, default=True, help="Set to True if the images are compressed."
)
parser.add_argument(
"--video-encoding", type=str2bool, default=True, help="Set to True to encode images as videos."
)
parser.add_argument("--nproc", type=int, default=10, help="Number of image writer processes.")
parser.add_argument("--nthreads", type=int, default=5, help="Number of image writer threads.")
args = parser.parse_args()
converter = DatasetConverter(
raw_path=args.raw_path,
dataset_repo_id=args.dataset_repo_id,
fps=args.fps,
robot_type=args.robot_type,
image_compressed=args.image_compressed,
encode_as_videos=args.video_encoding,
image_writer_processes=args.nproc,
image_writer_threads=args.nthreads,
)
converter.init_lerobot_dataset()
converter.extract_episodes(episode_description=args.description)
if args.push:
converter.push_dataset_to_hub(
dataset_tags=AlohaHD5Extractor.TAGS, private=args.private, push_videos=True, license=args.license
)
if __name__ == "__main__":
main()


@@ -0,0 +1,82 @@
from pathlib import Path
import cv2
import h5py
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
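# Minimal manual smoke test: build a LeRobotDataset from one locally recorded Aloha
# episode and push it to the Hugging Face Hub. The data path and repo id below are
# specific to the author's local setup.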
data_path = Path("/home/ccop/code/aloha_data")
def get_features(hdf5_file):
topics = []
features = {}
hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None)
for topic in topics:
if "images" in topic.split("/"):
features[topic.replace("/", ".")] = {
"dtype": "image",
"shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape,
"names": None,
}
elif "compress_len" in topic.split("/"):
continue
else:
features[topic.replace("/", ".")] = {
"dtype": str(hdf5_file[topic][0].dtype),
"shape": hdf5_file[topic][0].shape,
"names": None,
}
return features
def extract_episode(episode_path, features, n_frames, dataset):
with h5py.File(episode_path, "r") as file:
# List all groups
for frame_idx in range(n_frames):
frame = {}
for feature in features:
if "images" in feature.split("."):
frame[feature] = torch.from_numpy(
cv2.imdecode(file[feature.replace(".", "/")][frame_idx], 1).transpose(2, 0, 1)
)
else:
frame[feature] = torch.from_numpy(file[feature.replace(".", "/")][frame_idx])
dataset.add_frame(frame)
def get_dataset_properties(raw_folder):
    episode_list = sorted(raw_folder.glob("episode_*.hdf5"))
    with h5py.File(episode_list[0], "r") as file:
features = get_features(file)
n_frames = file["observations/images/cam_high"][:].shape[0]
return features, n_frames
if __name__ == "__main__":
raw_folder = data_path.absolute() / "aloha_stationary_replay_test"
episode_file = "episode_0.hdf5"
features, n_frames = get_dataset_properties(raw_folder)
dataset = LeRobotDataset.create(
repo_id="ccop/aloha_stationary_replay_test_v3",
fps=50,
robot_type="aloha-stationary",
features=features,
image_writer_threads=4,
)
extract_episode(raw_folder / episode_file, features, n_frames, dataset)
print("save episode!")
dataset.save_episode(
task="move_cube",
)
dataset.consolidate()
dataset.push_to_hub()