Merge e5d3ed4de9 into 768e36660d
commit 7bde337b49
@@ -0,0 +1,492 @@
import argparse
import logging
import os
import shutil
import traceback
from pathlib import Path

import cv2
import h5py
import torch

from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
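
# Example CLI usage (a sketch; the script file name, paths, and repo id below are
# hypothetical placeholders):
#   python convert_aloha_hdf5.py \
#       --raw-path /path/to/raw_episodes \
#       --dataset-repo-id <hf-user>/<dataset-name> \
#       --fps 50 --robot-type aloha-stationary \
#       --image-compressed true --video-encoding true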


class AlohaHD5Extractor:
    TAGS = ["aloha", "robotics", "hdf5"]
    aloha_stationary = "aloha-stationary"
    aloha_mobile = "aloha-mobile"
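
    # Expected raw episode layout (what check_format below verifies); the camera
    # name is an example and the dimensions are illustrative:
    #   /action                        -> (num_frames, action_dim)
    #   /observations/qpos             -> (num_frames, state_dim)
    #   /observations/images/cam_high  -> per-frame encoded byte buffers when
    #                                     image_compressed, else (num_frames, h, w, c)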

    @staticmethod
    def get_cameras(hdf5_data: h5py.File):
        """
        Extracts the list of RGB camera keys from the given HDF5 data.

        Parameters
        ----------
        hdf5_data : h5py.File
            The HDF5 file object containing the dataset.

        Returns
        -------
        list of str
            A list of keys corresponding to RGB cameras in the dataset.

        """

        rgb_cameras = [key for key in hdf5_data["/observations/images"] if "depth" not in key]
        return rgb_cameras

    @staticmethod
    def check_format(episode_list: list[str] | list[Path], image_compressed: bool = True):
        """
        Check the format of the given list of HDF5 files.

        Parameters
        ----------
        episode_list : list of str or list of Path
            List of paths to the HDF5 files to be checked.
        image_compressed : bool, optional
            Flag indicating whether the images are compressed (default is True).

        Raises
        ------
        ValueError
            If the episode_list is empty.
            If any HDF5 file is missing required keys '/action' or '/observations/qpos'.
            If the '/action' or '/observations/qpos' keys do not have 2 dimensions.
            If the number of frames in '/action' and '/observations/qpos' keys do not match.
            If the number of frames in '/observations/images/{camera}' does not match the number of
            frames in '/action' and '/observations/qpos'.
            If the dimensions of images do not match the expected dimensions based on the
            image_compressed flag.
            If uncompressed images do not have the expected (h, w, c) format.

        """

        if not episode_list:
            raise ValueError(
                "No hdf5 files found in the raw directory. Make sure they are named 'episode_*.hdf5'"
            )
        for episode_path in episode_list:
            with h5py.File(episode_path, "r") as data:
                if not all(key in data for key in ["/action", "/observations/qpos"]):
                    raise ValueError(
                        "Missing required keys in the hdf5 file. Make sure the keys '/action' and '/observations/qpos' are present."
                    )

                if not data["/action"].ndim == data["/observations/qpos"].ndim == 2:
                    raise ValueError(
                        "The '/action' and '/observations/qpos' keys should both have 2 dimensions."
                    )

                if (num_frames := data["/action"].shape[0]) != data["/observations/qpos"].shape[0]:
                    raise ValueError(
                        "The '/action' and '/observations/qpos' keys should have the same number of frames."
                    )

                for camera in AlohaHD5Extractor.get_cameras(data):
                    if num_frames != data[f"/observations/images/{camera}"].shape[0]:
                        raise ValueError(
                            f"The number of frames in '/observations/images/{camera}' should be the same as in '/action' and '/observations/qpos' keys."
                        )

                    expected_dims = 2 if image_compressed else 4
                    if data[f"/observations/images/{camera}"].ndim != expected_dims:
                        raise ValueError(
                            f"Expect {expected_dims} dimensions for {'compressed' if image_compressed else 'uncompressed'} images but {data[f'/observations/images/{camera}'].ndim} provided."
                        )
                    if not image_compressed:
                        _, h, w, c = data[f"/observations/images/{camera}"].shape
                        if not (c < h and c < w):
                            raise ValueError(f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided.")

    @staticmethod
    def extract_episode_frames(
        episode_path: str | Path, features: dict[str, dict], image_compressed: bool
    ) -> list[dict[str, torch.Tensor]]:
        """
        Extract frames from an episode stored in an HDF5 file.

        Parameters
        ----------
        episode_path : str or Path
            Path to the HDF5 file containing the episode data.
        features : dict of str to dict
            Dictionary where keys are feature identifiers and values are dictionaries with feature details.
        image_compressed : bool
            Flag indicating whether the images are stored in a compressed format.

        Returns
        -------
        list of dict of str to torch.Tensor
            List of frames, where each frame is a dictionary mapping feature identifiers to tensors.

        """
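        # Feature ids follow the LeRobot naming produced by define_features and are mapped
        # back to HDF5 dataset paths below, e.g. (camera name is an example):
        #   "action"                      -> "action"
        #   "observation.state"           -> "observations/qpos"
        #   "observation.images.cam_high" -> "observations/images/cam_high"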
        frames = []
        with h5py.File(episode_path, "r") as file:
            for frame_idx in range(file["/action"].shape[0]):
                frame = {}
                for feature_id in features:
                    feature_name_hd5 = (
                        feature_id.replace(".", "/")
                        .replace("observation", "observations")
                        .replace("state", "qpos")
                    )
                    if "images" in feature_id.split("."):
                        image = (
                            (file[feature_name_hd5][frame_idx])
                            if not image_compressed
                            else cv2.imdecode(file[feature_name_hd5][frame_idx], 1)
                        )
                        frame[feature_id] = torch.from_numpy(image.transpose(2, 0, 1))
                    else:
                        frame[feature_id] = torch.from_numpy(file[feature_name_hd5][frame_idx])
                frames.append(frame)
        return frames

    @staticmethod
    def define_features(
        hdf5_file_path: Path, image_compressed: bool = True, encode_as_video: bool = True
    ) -> dict[str, dict]:
        """
        Define features from an HDF5 file.

        Parameters
        ----------
        hdf5_file_path : Path
            The path to the HDF5 file.
        image_compressed : bool, optional
            Whether the images are compressed, by default True.
        encode_as_video : bool, optional
            Whether to encode images as video or as images, by default True.

        Returns
        -------
        dict[str, dict]
            A dictionary where keys are topic names and values are dictionaries
            containing feature information such as dtype, shape, and names.

        """
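        # Sketch of a returned mapping (camera name, shapes, and dtype are illustrative):
        #   {
        #       "observation.images.cam_high": {"dtype": "video", "shape": (3, 480, 640),
        #                                       "names": ["channel", "height", "width"]},
        #       "observation.state": {"dtype": "float64", "shape": (14,),
        #                             "names": ["qpos_0", ..., "qpos_13"]},
        #   }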
        # Initialize lists to store topics and features
        topics: list[str] = []
        features: dict[str, dict] = {}

        # Open the HDF5 file
        with h5py.File(hdf5_file_path, "r") as hdf5_file:
            # Collect all dataset names in the HDF5 file
            hdf5_file.visititems(
                lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None
            )

            # Iterate over each topic to define its features
            for topic in topics:
                # Map the HDF5 topic name to the LeRobot feature name
                destination_topic = (
                    topic.replace("/", ".").replace("observations", "observation").replace("qpos", "state")
                )
                # If the topic is an image, define it as a video (or image) feature
                if "images" in topic.split("/"):
                    sample = hdf5_file[topic][0]
                    features[destination_topic] = {
                        "dtype": "video" if encode_as_video else "image",
                        # Shape is channel-first (c, h, w) to match the transpose applied in
                        # extract_episode_frames, for both compressed and uncompressed images
                        "shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape
                        if image_compressed
                        else sample.transpose(2, 0, 1).shape,
                        "names": [
                            "channel",
                            "height",
                            "width",
                        ],
                    }
                # Skip compressed length topics
                elif "compress_len" in topic.split("/"):
                    continue
                # Otherwise, define it as a regular feature
                else:
                    features[destination_topic] = {
                        "dtype": str(hdf5_file[topic][0].dtype),
                        "shape": (topic_shape := hdf5_file[topic][0].shape),
                        "names": [f"{topic.split('/')[-1]}_{k}" for k in range(topic_shape[0])],
                    }
        # Return the defined features
        return features


class DatasetConverter:
    """
    A class to convert datasets to LeRobot format.

    Parameters
    ----------
    raw_path : Path or str
        The path to the raw dataset.
    dataset_repo_id : str
        The repository ID where the dataset will be stored.
    fps : int
        Frames per second for the dataset.
    robot_type : str, optional
        The type of robot, by default "".
    encode_as_videos : bool, optional
        Whether to encode images as videos, by default True.
    image_compressed : bool, optional
        Whether the images are compressed, by default True.
    image_writer_processes : int, optional
        Number of processes for writing images, by default 0.
    image_writer_threads : int, optional
        Number of threads for writing images, by default 0.

    Methods
    -------
    extract_episode(episode_path, task_description='')
        Extracts frames from a single episode and saves it with a description.
    extract_episodes(episode_description='')
        Extracts frames from all episodes and saves them with a description.
    push_dataset_to_hub(dataset_tags=None, private=False, push_videos=True, license="apache-2.0")
        Pushes the dataset to the Hugging Face Hub.
    init_lerobot_dataset()
        Initializes the LeRobot dataset.

    """
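
    # Typical programmatic use (a sketch mirroring main() at the bottom of this file):
    #   converter = DatasetConverter(raw_path=..., dataset_repo_id=..., fps=50)
    #   converter.init_lerobot_dataset()
    #   converter.extract_episodes(episode_description="...")
    #   converter.push_dataset_to_hub()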

    def __init__(
        self,
        raw_path: Path | str,
        dataset_repo_id: str,
        fps: int,
        robot_type: str = "",
        encode_as_videos: bool = True,
        image_compressed: bool = True,
        image_writer_processes: int = 0,
        image_writer_threads: int = 0,
    ):
        self.raw_path = raw_path if isinstance(raw_path, Path) else Path(raw_path)
        self.dataset_repo_id = dataset_repo_id
        self.fps = fps
        self.robot_type = robot_type
        self.image_compressed = image_compressed
        self.image_writer_threads = image_writer_threads
        self.image_writer_processes = image_writer_processes
        self.encode_as_videos = encode_as_videos

        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(logging.INFO)

        # Add console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        formatter = logging.Formatter("%(asctime)s - [%(name)s] - %(message)s")
        console_handler.setFormatter(formatter)
        self.logger.addHandler(console_handler)

        self.logger.info(f"{'-' * 10} Aloha HD5 -> Lerobot Converter {'-' * 10}")
        self.logger.info(f"Processing Aloha HD5 dataset from {self.raw_path}")
        self.logger.info(f"Dataset will be stored in {self.dataset_repo_id}")
        self.logger.info(f"FPS: {self.fps}")
        self.logger.info(f"Robot type: {self.robot_type}")
        self.logger.info(f"Image compressed: {self.image_compressed}")
        self.logger.info(f"Encoding images as videos: {self.encode_as_videos}")
        self.logger.info(f"#writer processes: {self.image_writer_processes}")
        self.logger.info(f"#writer threads: {self.image_writer_threads}")

        self.episode_list = list(self.raw_path.glob("episode_*.hdf5"))
        AlohaHD5Extractor.check_format(self.episode_list, image_compressed=self.image_compressed)
        self.features = AlohaHD5Extractor.define_features(
            self.episode_list[0],
            image_compressed=self.image_compressed,
            encode_as_video=self.encode_as_videos,
        )

    def extract_episode(self, episode_path, task_description: str = ""):
        """
        Extracts frames from an episode and saves them to the dataset.

        Parameters
        ----------
        episode_path : str
            The path to the episode file.
        task_description : str, optional
            A description of the task associated with the episode (default is an empty string).

        Returns
        -------
        None
        """

        for frame in AlohaHD5Extractor.extract_episode_frames(
            episode_path, self.features, self.image_compressed
        ):
            self.dataset.add_frame(frame)
        self.logger.info(f"Saving Episode with Description: {task_description} ...")
        self.dataset.save_episode(task=task_description)

    def extract_episodes(self, episode_description: str = ""):
        """
        Extracts episodes from the episode list and processes them.

        Parameters
        ----------
        episode_description : str, optional
            A description of the task to be passed to the extract_episode method (default is '').

        Raises
        ------
        Exception
            If an error occurs during the processing of an episode, it will be caught and logged.

        Notes
        -----
        After processing all episodes, the dataset is consolidated.

        """

        for episode_path in self.episode_list:
            try:
                self.extract_episode(episode_path, task_description=episode_description)
            except Exception as e:
                self.logger.error(f"Error processing episode {episode_path}: {e}")
                traceback.print_exc()
                continue
        self.dataset.consolidate()

    def push_dataset_to_hub(
        self,
        dataset_tags: list[str] | None = None,
        private: bool = False,
        push_videos: bool = True,
        license: str | None = "apache-2.0",
    ):
        """
        Pushes the dataset to the Hugging Face Hub.

        Parameters
        ----------
        dataset_tags : list of str, optional
            A list of tags to associate with the dataset on the Hub. Default is None.
        private : bool, optional
            If True, the dataset will be private. Default is False.
        push_videos : bool, optional
            If True, videos will be pushed along with the dataset. Default is True.
        license : str, optional
            The license under which the dataset is released. Default is "apache-2.0".

        Returns
        -------
        None
        """

        self.logger.info(f"Pushing dataset to Hugging Face Hub. ID: {self.dataset_repo_id} ...")
        self.dataset.push_to_hub(
            tags=dataset_tags,
            license=license,
            push_videos=push_videos,
            private=private,
        )

    def init_lerobot_dataset(self):
        """
        Initializes the LeRobot dataset.

        This method cleans the cache if the dataset already exists and then creates a new LeRobot dataset.

        Returns
        -------
        LeRobotDataset
            The initialized LeRobot dataset.

        """

        # Clean the cache if the dataset already exists
        if os.path.exists(LEROBOT_HOME / self.dataset_repo_id):
            shutil.rmtree(LEROBOT_HOME / self.dataset_repo_id)
        self.dataset = LeRobotDataset.create(
            repo_id=self.dataset_repo_id,
            fps=self.fps,
            robot_type=self.robot_type,
            features=self.features,
            image_writer_threads=self.image_writer_threads,
            image_writer_processes=self.image_writer_processes,
        )

        return self.dataset


def str2bool(value):
    if isinstance(value, bool):
        return value
    value = value.lower()
    if value in ("yes", "true", "t", "y", "1"):
        return True
    elif value in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def main():
    """
    Convert Aloha HD5 dataset and push to Hugging Face hub.

    This script processes raw HDF5 files from the Aloha dataset, converts them into a specified format,
    and optionally uploads the dataset to the Hugging Face hub.

    Parameters
    ----------
    --raw-path : Path
        Directory containing the raw HDF5 files.
    --dataset-repo-id : str
        Repository ID where the dataset will be stored.
    --fps : int
        Frames per second for the dataset.
    --description : str, optional
        Description of the dataset. Default is "Aloha recorded dataset.".
    --robot-type : str, optional
        Type of robot, either "aloha-stationary" or "aloha-mobile". Default is "aloha-stationary".
    --private : bool, optional
        Set to True to make the dataset private. Default is False.
    --push : bool, optional
        Set to True to push the dataset to the hub. Default is True.
    --license : str, optional
        License for the dataset. Default is "apache-2.0".
    --image-compressed : bool, optional
        Set to True if the images are compressed. Default is True.
    --video-encoding : bool, optional
        Set to True to encode images as videos. Default is True.
    --nproc : int, optional
        Number of image writer processes. Default is 10.
    --nthreads : int, optional
        Number of image writer threads. Default is 5.
    """

    parser = argparse.ArgumentParser(description="Convert Aloha HD5 dataset and push to Hugging Face hub.")
    parser.add_argument(
        "--raw-path", type=Path, required=True, help="Directory containing the raw hdf5 files."
    )
    parser.add_argument(
        "--dataset-repo-id", type=str, required=True, help="Repository ID where the dataset will be stored."
    )
    parser.add_argument("--fps", type=int, required=True, help="Frames per second for the dataset.")
    parser.add_argument(
        "--description", type=str, help="Description of the dataset.", default="Aloha recorded dataset."
    )

    parser.add_argument(
        "--robot-type",
        type=str,
        choices=["aloha-stationary", "aloha-mobile"],
        default="aloha-stationary",
        help="Type of robot.",
    )
    parser.add_argument(
        "--private", type=str2bool, default=False, help="Set to True to make the dataset private."
    )
    parser.add_argument(
        "--push", type=str2bool, default=True, help="Set to True to push the dataset to the hub."
    )
    parser.add_argument("--license", type=str, default="apache-2.0", help="License for the dataset.")
    parser.add_argument(
        "--image-compressed", type=str2bool, default=True, help="Set to True if the images are compressed."
    )
    parser.add_argument(
        "--video-encoding", type=str2bool, default=True, help="Set to True to encode images as videos."
    )

    parser.add_argument("--nproc", type=int, default=10, help="Number of image writer processes.")
    parser.add_argument("--nthreads", type=int, default=5, help="Number of image writer threads.")

    args = parser.parse_args()

    converter = DatasetConverter(
        raw_path=args.raw_path,
        dataset_repo_id=args.dataset_repo_id,
        fps=args.fps,
        robot_type=args.robot_type,
        image_compressed=args.image_compressed,
        encode_as_videos=args.video_encoding,
        image_writer_processes=args.nproc,
        image_writer_threads=args.nthreads,
    )
    converter.init_lerobot_dataset()
    converter.extract_episodes(episode_description=args.description)

    if args.push:
        converter.push_dataset_to_hub(
            dataset_tags=AlohaHD5Extractor.TAGS, private=args.private, push_videos=True, license=args.license
        )


if __name__ == "__main__":
    main()
@@ -0,0 +1,82 @@
from pathlib import Path

import cv2
import h5py
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

data_path = Path("/home/ccop/code/aloha_data")


def get_features(hdf5_file):
    topics = []
    features = {}
    hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None)
    for topic in topics:
        if "images" in topic.split("/"):
            features[topic.replace("/", ".")] = {
                "dtype": "image",
                "shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape,
                "names": None,
            }
        elif "compress_len" in topic.split("/"):
            continue
        else:
            features[topic.replace("/", ".")] = {
                "dtype": str(hdf5_file[topic][0].dtype),
                "shape": hdf5_file[topic][0].shape,
                "names": None,
            }

    return features


def extract_episode(episode_path, features, n_frames, dataset):
    with h5py.File(episode_path, "r") as file:
        # Build one frame dict per time step and add it to the dataset
        for frame_idx in range(n_frames):
            frame = {}
            for feature in features:
                if "images" in feature.split("."):
                    frame[feature] = torch.from_numpy(
                        cv2.imdecode(file[feature.replace(".", "/")][frame_idx], 1).transpose(2, 0, 1)
                    )
                else:
                    frame[feature] = torch.from_numpy(file[feature.replace(".", "/")][frame_idx])

            dataset.add_frame(frame)


def get_dataset_properties(raw_folder):
    from os import listdir

    episode_list = listdir(raw_folder)
    with h5py.File(raw_folder / episode_list[0], "r") as file:
        features = get_features(file)
        # Frame count is read from the cam_high stream and assumed identical for all topics
        n_frames = file["observations/images/cam_high"].shape[0]
    return features, n_frames


if __name__ == "__main__":
    raw_folder = data_path.absolute() / "aloha_stationary_replay_test"
    episode_file = "episode_0.hdf5"

    features, n_frames = get_dataset_properties(raw_folder)

    dataset = LeRobotDataset.create(
        repo_id="ccop/aloha_stationary_replay_test_v3",
        fps=50,
        robot_type="aloha-stationary",
        features=features,
        image_writer_threads=4,
    )

    extract_episode(raw_folder / episode_file, features, n_frames, dataset)
    print("save episode!")
    dataset.save_episode(
        task="move_cube",
    )
    dataset.consolidate()
    dataset.push_to_hub()