Ben Sprenger 2025-04-15 15:11:52 +02:00 committed by GitHub
commit 3810d1e764
2 changed files with 190 additions and 4 deletions

View File: lerobot/common/datasets/lerobot_dataset.py

@@ -23,6 +23,7 @@ import datasets
 import numpy as np
 import packaging.version
 import PIL.Image
+import pyarrow.dataset as pyarrow_ds
 import torch
 import torch.utils
 from datasets import concatenate_datasets, load_dataset
@@ -615,12 +616,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
         if self.episodes is None:
-            path = str(self.root / "data")
-            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
+            path_or_files = str(self.root / "data")
         else:
-            files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
-            hf_dataset = load_dataset("parquet", data_files=files, split="train")
+            path_or_files = [
+                str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes
+            ]
+        # HACK: load_dataset does not directly support hive partitioning
+        # however, it is needed to sync the local cache with the hub
+        # here we use load_dataset to sync the local cache, and then reload
+        # with hive partitioning
+        load_dataset("parquet", data_dir=path_or_files, split="train")
+        # now that the cache is synced, we can load the dataset with hive partitioning
+        pa_dataset = pyarrow_ds.dataset(path_or_files, format="parquet", partitioning="hive")
+        table = pa_dataset.to_table()
+        hf_dataset = datasets.Dataset(arrow_table=table)
         # TODO(aliberts): hf_dataset.set_format("torch")
         hf_dataset.set_transform(hf_transform_to_torch)
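The hive-partitioning workaround above relies on pyarrow inferring a partition column from directory names of the form episode_index=XXXXXX, so the column that is no longer stored inside the parquet files reappears when the table is materialized. A minimal sketch of that behaviour, not part of the commit and using a made-up local layout:

import pyarrow.dataset as pyarrow_ds

# Hypothetical layout, matching the new data_path convention:
#   data/chunk-000/episode_index=000000/episode_000000.parquet
#   data/chunk-000/episode_index=000001/episode_000001.parquet
pa_dataset = pyarrow_ds.dataset("data", format="parquet", partitioning="hive")
table = pa_dataset.to_table()

# 'episode_index' is recovered from the folder names as a regular column.
print("episode_index" in table.column_names)  # True when the hive folders exist

The preliminary load_dataset(...) call in the diff is kept only to sync the local cache with the Hub; the table that is actually consumed comes from the pyarrow dataset.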

View File

@@ -0,0 +1,173 @@
import argparse
import logging
import re
from pathlib import Path

from datasets import Dataset
from huggingface_hub import HfApi

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import load_info, write_info
from lerobot.common.utils.utils import init_logging


def process_parquet_files(root_path: Path) -> None:
"""
Processes parquet files in data/chunk-* folders.
For each `episode_XXX.parquet` file found directly within a `chunk-YYY` directory:
1. Removes the 'episode_index' column.
2. Saves the modified data to a new directory structure:
`data/chunk-YYY/episode_index=XXX/episode_XXX.parquet`.
3. Records the original file path for later removal.
Args:
root_path: The root directory of the LeRobot dataset.
Returns:
A list of relative paths (from dataset root) of the original parquet files
that were processed and removed locally.
"""
data_folder = root_path / "data"
chunk_pattern = re.compile(r"^chunk-(\d+)$")
episode_pattern = re.compile(r"^episode_(\d+)\.parquet$")
# Process only valid chunk directories
chunk_dirs = [d for d in data_folder.iterdir() if d.is_dir() and chunk_pattern.match(d.name)]
for chunk_path in chunk_dirs:
logging.info(f"Processing chunk directory: {chunk_path.name}")
# Find all episode files in this chunk directory
episode_files = [f for f in chunk_path.iterdir() if f.is_file() and episode_pattern.match(f.name)]
for parquet_file_path in episode_files:
match = episode_pattern.match(parquet_file_path.name)
if not match:
continue
episode_index = match.group(1)
try:
# Load, process, and save in new location
ds = Dataset.from_parquet(str(parquet_file_path))
ds = ds.remove_columns(["episode_index"])
new_folder = chunk_path / f"episode_index={episode_index}"
new_file_path = new_folder / parquet_file_path.name
new_folder.mkdir(exist_ok=True)
ds.to_parquet(new_file_path)
# Record for removal and delete original
parquet_file_path.unlink()
logging.debug(f"Processed {parquet_file_path.name} successfully")
except Exception as e:
logging.error(f"Error processing {parquet_file_path.name}: {e}")


def rewrite_info_file(root_path: Path) -> None:
    """
    Updates the info.json file with the new data_path format and codebase_version.

    Args:
        root_path: The root directory of the LeRobot dataset.
    """
    info = load_info(root_path)
    info["data_path"] = (
        "data/chunk-{episode_chunk:03d}/episode_index={episode_index:06d}/episode_{episode_index:06d}.parquet"
    )
    info["codebase_version"] = "v2.2"
    write_info(info, root_path)
    logging.info("Updated info.json with new data path format and codebase version")


def parse_arguments() -> argparse.Namespace:
    """Parse and return command line arguments."""
    parser = argparse.ArgumentParser(
        description="Process and restructure parquet files in a LeRobot dataset, "
        "removing the 'episode_index' column and updating the directory structure. "
        "Optionally pushes changes to the Hugging Face Hub."
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of the Hugging Face repository containing the LeRobotDataset (e.g., `lerobot/pusht`).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for the dataset stored locally. If None, it will use the Hugging Face cache "
        "or download from the Hub.",
    )
    parser.add_argument(
        "--branch",
        type=str,
        default=None,  # Use None to default to 'main'
        help="The git revision (branch, tag, or commit hash) of the dataset to process and push to. Defaults to 'main'.",
    )
    parser.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
        help="Upload to Hugging Face hub.",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="If set, the repository on the Hub will be private",
    )
    parser.add_argument(
        "--tags",
        type=str,
        nargs="+",
        help="List of tags to apply to the dataset on the Hub",
    )
    parser.add_argument(
        "--upload-large-folder",
        action="store_true",
        help="If set, upload large folders to the Hub using the `upload_large_folder` method",
    )
    parser.add_argument("--license", type=str, default=None, help="License to use for the dataset on the Hub")

    return parser.parse_args()


def main() -> None:
    args = parse_arguments()

    logging.info(f"Loading dataset '{args.repo_id}'...")
    dataset = LeRobotDataset(args.repo_id, root=args.root)
    logging.info(f"Dataset loaded from {dataset.root}")

    process_parquet_files(dataset.root)
    rewrite_info_file(dataset.root)

    if args.push_to_hub:
        logging.info(f"Pushing changes to the Hugging Face Hub repository '{args.repo_id}'...")
        hub_api = HfApi()
        hub_api.delete_folder(
            repo_id=args.repo_id,
            path_in_repo="data/",
            repo_type="dataset",
            commit_message="chore: remove old data folder",
        )
        dataset.push_to_hub(
            branch=args.branch,
            private=bool(args.private),
            tags=args.tags,
            license=args.license,
            upload_large_folder=bool(args.upload_large_folder),
        )
        logging.info("Changes pushed to the Hugging Face Hub successfully.")
    else:
        logging.info("Skipping push to Hub. Dataset updated locally only.")


if __name__ == "__main__":
    init_logging()
    main()
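For reference, the data_path template that rewrite_info_file writes into info.json resolves to the hive-partitioned layout produced by process_parquet_files. A small illustrative sketch of how the template expands; the chunk and episode numbers are made up for the example and are not part of the commit:

# Illustrative only: expanding the new data_path template from info.json.
DATA_PATH = "data/chunk-{episode_chunk:03d}/episode_index={episode_index:06d}/episode_{episode_index:06d}.parquet"

relative_path = DATA_PATH.format(episode_chunk=0, episode_index=42)
print(relative_path)
# data/chunk-000/episode_index=000042/episode_000042.parquet

The episode_index=... directory segment is what lets load_hf_dataset reload the datasets with pyarrow hive partitioning, even though the column itself is stripped from each parquet file.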