From 017d5628f1c3775cfbdae5aad35c2a443b1159b6 Mon Sep 17 00:00:00 2001
From: Ben Sprenger
Date: Sun, 23 Mar 2025 21:41:34 +0100
Subject: [PATCH] feat(dataset): support hive partitioning schema

---
 lerobot/common/datasets/lerobot_dataset.py    |  21 ++-
 .../v22/convert_dataset_v21_to_v22.py         | 173 ++++++++++++++++++
 2 files changed, 190 insertions(+), 4 deletions(-)
 create mode 100644 lerobot/common/datasets/v22/convert_dataset_v21_to_v22.py

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index d8da85d6..d66c6b12 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -23,6 +23,7 @@ import datasets
 import numpy as np
 import packaging.version
 import PIL.Image
+import pyarrow.dataset as pyarrow_ds
 import torch
 import torch.utils
 from datasets import concatenate_datasets, load_dataset
@@ -615,12 +616,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
+
         if self.episodes is None:
-            path = str(self.root / "data")
-            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
+            path_or_files = str(self.root / "data")
         else:
-            files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
-            hf_dataset = load_dataset("parquet", data_files=files, split="train")
+            path_or_files = [
+                str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes
+            ]
+
+        # HACK: load_dataset does not directly support hive partitioning, but calling it
+        # is still needed to sync the local cache with the hub. Use it here only to
+        # populate the cache, then reload the data below with hive partitioning.
+        kwargs = {"data_dir": path_or_files} if self.episodes is None else {"data_files": path_or_files}
+        load_dataset("parquet", split="train", **kwargs)
+
+        # now that the cache is synced, we can load the dataset with hive partitioning
+        pa_dataset = pyarrow_ds.dataset(path_or_files, format="parquet", partitioning="hive")
+        table = pa_dataset.to_table()
+        hf_dataset = datasets.Dataset(arrow_table=table)
 
         # TODO(aliberts): hf_dataset.set_format("torch")
         hf_dataset.set_transform(hf_transform_to_torch)
diff --git a/lerobot/common/datasets/v22/convert_dataset_v21_to_v22.py b/lerobot/common/datasets/v22/convert_dataset_v21_to_v22.py
new file mode 100644
index 00000000..6a1ea6d5
--- /dev/null
+++ b/lerobot/common/datasets/v22/convert_dataset_v21_to_v22.py
@@ -0,0 +1,173 @@
+import argparse
+import logging
+import re
+from pathlib import Path
+
+from datasets import Dataset
+from huggingface_hub import HfApi
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import load_info, write_info
+from lerobot.common.utils.utils import init_logging
+
+
+def process_parquet_files(root_path: Path) -> None:
+    """
+    Processes parquet files in data/chunk-* folders.
+
+    For each `episode_XXX.parquet` file found directly within a `chunk-YYY` directory:
+    1. Removes the 'episode_index' column.
+    2. Saves the modified data to a new directory structure:
+       `data/chunk-YYY/episode_index=XXX/episode_XXX.parquet`.
+    3. Deletes the original file once the new copy has been written.
+
+    Args:
+        root_path: The root directory of the LeRobot dataset.
+
+    Note:
+        The new layout follows hive partitioning: `episode_index` is encoded in the
+        directory name rather than stored as a column in the parquet file.
+    """
+    data_folder = root_path / "data"
+
+    chunk_pattern = re.compile(r"^chunk-(\d+)$")
+    episode_pattern = re.compile(r"^episode_(\d+)\.parquet$")
+
+    # Process only valid chunk directories
+    chunk_dirs = [d for d in data_folder.iterdir() if d.is_dir() and chunk_pattern.match(d.name)]
+
+    for chunk_path in chunk_dirs:
+        logging.info(f"Processing chunk directory: {chunk_path.name}")
+
+        # Find all episode files in this chunk directory
+        episode_files = [f for f in chunk_path.iterdir() if f.is_file() and episode_pattern.match(f.name)]
+
+        for parquet_file_path in episode_files:
+            match = episode_pattern.match(parquet_file_path.name)
+            if not match:
+                continue
+
+            episode_index = match.group(1)
+            try:
+                # Load, process, and save in new location
+                ds = Dataset.from_parquet(str(parquet_file_path))
+                ds = ds.remove_columns(["episode_index"])
+
+                new_folder = chunk_path / f"episode_index={episode_index}"
+                new_file_path = new_folder / parquet_file_path.name
+
+                new_folder.mkdir(exist_ok=True)
+                ds.to_parquet(new_file_path)
+
+                # Delete the original file now that the new copy exists
+                parquet_file_path.unlink()
+
+                logging.debug(f"Processed {parquet_file_path.name} successfully")
+
+            except Exception as e:
+                logging.error(f"Error processing {parquet_file_path.name}: {e}")
+
+
+def rewrite_info_file(root_path: Path) -> None:
+    """
+    Updates the info.json file with the new data_path format and codebase_version.
+
+    Args:
+        root_path: The root directory of the LeRobot dataset.
+    """
+    info = load_info(root_path)
+    info["data_path"] = (
+        "data/chunk-{episode_chunk:03d}/episode_index={episode_index:06d}/episode_{episode_index:06d}.parquet"
+    )
+    info["codebase_version"] = "v2.2"
+    write_info(info, root_path)
+    logging.info("Updated info.json with new data path format and codebase version")
+
+
+def parse_arguments() -> argparse.Namespace:
+    """Parse and return command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Process and restructure parquet files in a LeRobot dataset, "
+        "removing the 'episode_index' column and updating the directory structure. "
+        "Optionally pushes changes to the Hugging Face Hub."
+    )
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Name of the Hugging Face repository containing the LeRobotDataset (e.g., `lerobot/pusht`).",
+    )
+    parser.add_argument(
+        "--root",
+        type=Path,
+        default=None,
+        help="Root directory for the dataset stored locally. If None, it will use the Hugging Face cache "
+        "or download from the Hub.",
+    )
+    parser.add_argument(
+        "--branch",
+        type=str,
+        default=None,  # Use None to default to 'main'
+        help="The git revision (branch, tag, or commit hash) of the dataset to process and push to. "
+        "Defaults to 'main'.",
+    )
+    parser.add_argument(
+        "--push-to-hub",
+        type=int,
+        default=1,
+        help="Upload the converted dataset to the Hugging Face Hub (use 0 to skip).",
+    )
+    parser.add_argument(
+        "--private",
+        action="store_true",
+        help="If set, the repository on the Hub will be private",
+    )
+    parser.add_argument(
+        "--tags",
+        type=str,
+        nargs="+",
+        help="List of tags to apply to the dataset on the Hub",
+    )
+    parser.add_argument(
+        "--upload-large-folder",
+        action="store_true",
+        help="If set, upload large folders to the Hub using the `upload_large_folder` method",
+    )
+    parser.add_argument("--license", type=str, default=None, help="License to use for the dataset on the Hub")
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_arguments()
+
+    logging.info(f"Loading dataset '{args.repo_id}'...")
+    dataset = LeRobotDataset(args.repo_id, root=args.root)
+    logging.info(f"Dataset loaded from {dataset.root}")
+
+    process_parquet_files(dataset.root)
+    rewrite_info_file(dataset.root)
+
+    if args.push_to_hub:
+        logging.info(f"Pushing changes to the Hugging Face Hub repository '{args.repo_id}'...")
+        hub_api = HfApi()
+        hub_api.delete_folder(
+            repo_id=args.repo_id,
+            path_in_repo="data/",
+            repo_type="dataset",
+            commit_message="chore: remove old data folder",
+        )
+
+        dataset.push_to_hub(
+            branch=args.branch,
+            private=bool(args.private),
+            tags=args.tags,
+            license=args.license,
+            upload_large_folder=bool(args.upload_large_folder),
+        )
+        logging.info("Changes pushed to the Hugging Face Hub successfully.")
+    else:
+        logging.info("Skipping push to Hub. Dataset updated locally only.")
+
+
+if __name__ == "__main__":
+    init_logging()
+    main()
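
Note (illustrative, not part of the patch): after conversion, `episode_index` lives in the `episode_index=XXXXXX` directory names instead of a parquet column, and the script is invoked via its argparse interface (e.g. `python lerobot/common/datasets/v22/convert_dataset_v21_to_v22.py --repo-id lerobot/pusht`). The minimal sketch below, using a hypothetical local root path, shows how pyarrow's hive partitioning recovers that column and permits partition pruning when only some episodes are needed; it mirrors the `load_hf_dataset()` change above rather than introducing any new API.

    import datasets
    import pyarrow.dataset as pyarrow_ds

    root = "path/to/local/dataset"  # hypothetical root of a converted (v2.2) dataset

    # "hive" partitioning parses key=value directory names, so episode_index
    # reappears as a column even though it was removed from the parquet files.
    pa_dataset = pyarrow_ds.dataset(f"{root}/data", format="parquet", partitioning="hive")

    # Partition pruning: only files under episode_index=000003 are read from disk.
    table = pa_dataset.to_table(filter=pyarrow_ds.field("episode_index") == 3)

    # Wrap the Arrow table in a Hugging Face Dataset, as load_hf_dataset() does.
    hf_dataset = datasets.Dataset(arrow_table=table)
    print(hf_dataset.column_names)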