lerobot/lerobot/scripts/push_dataset_to_hub.py

333 lines
12 KiB
Python

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
installation of neural net specific packages like pytorch, tensorflow, jax.
Examples of how to convert raw datasets (already downloaded into `--raw-dir`) into LeRobotDataset format and push them to the hub:
```
python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir data/pusht_raw \
--raw-format pusht_zarr \
--repo-id lerobot/pusht
python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir data/xarm_lift_medium_raw \
--raw-format xarm_pkl \
--repo-id lerobot/xarm_lift_medium
python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir data/aloha_sim_insertion_scripted_raw \
--raw-format aloha_hdf5 \
--repo-id lerobot/aloha_sim_insertion_scripted
python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir data/umi_cup_in_the_wild_raw \
--raw-format umi_zarr \
--repo-id lerobot/umi_cup_in_the_wild
```
"""
import argparse
import json
import shutil
import warnings
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import flatten_dict
def get_from_raw_to_lerobot_format_fn(raw_format: str):
    """Return the `from_raw_to_lerobot_format` function that handles `raw_format`.

    The converter modules are imported inside each branch, so only the module of the
    requested format is imported.

    Raises:
        ValueError: If `raw_format` is not one of the formats handled below.
    """
    # Guard clauses: each supported format imports its converter and returns immediately.
    if raw_format == "pusht_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format

        return from_raw_to_lerobot_format

    if raw_format == "umi_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format

        return from_raw_to_lerobot_format

    if raw_format == "aloha_hdf5":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format

        return from_raw_to_lerobot_format

    if raw_format == "dora_parquet":
        from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format

        return from_raw_to_lerobot_format

    if raw_format == "xarm_pkl":
        from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format

        return from_raw_to_lerobot_format

    if raw_format == "cam_png":
        from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format

        return from_raw_to_lerobot_format

    # None of the known formats matched.
    raise ValueError(
        f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
    )
def save_meta_data(
    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
):
    """Write `info`, `stats` and `episode_data_index` as files inside `meta_data_dir`.

    Files written:
        - `info.json`: `info` dumped as JSON with a 4-space indent.
        - `stats.safetensors`: `stats` after being flattened with `flatten_dict`.
        - `episode_data_index.safetensors`: every value of `episode_data_index` converted
          with `torch.tensor`.

    `meta_data_dir` is created (parents included) when it does not exist yet.
    """
    meta_data_dir.mkdir(parents=True, exist_ok=True)

    # info -> json
    with open(str(meta_data_dir / "info.json"), "w") as info_file:
        json.dump(info, info_file, indent=4)

    # stats -> safetensors (flattened first)
    flat_stats = flatten_dict(stats)
    save_file(flat_stats, meta_data_dir / "stats.safetensors")

    # episode_data_index -> safetensors (values converted to tensors first)
    index_tensors = {name: torch.tensor(values) for name, values in episode_data_index.items()}
    save_file(index_tensors, meta_data_dir / "episode_data_index.safetensors")
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
    """Upload the content of `meta_data_dir` to the dataset repository `repo_id`.

    All meta data files are expected to live in the single directory `meta_data_dir`.
    On the Hugging Face repository they end up in a "meta_data" directory at the root,
    on the given `revision`.
    """
    upload_kwargs = {
        "folder_path": meta_data_dir,
        "path_in_repo": "meta_data",
        "repo_id": repo_id,
        "revision": revision,
        "repo_type": "dataset",
    }
    HfApi().upload_folder(**upload_kwargs)
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
    """Upload the mp4 files of `videos_dir` to the dataset repository `repo_id`.

    All mp4 files are expected to live in the single directory `videos_dir`.
    On the Hugging Face repository they end up in a "videos" directory at the root,
    on the given `revision`. Only files matching "*.mp4" are uploaded.
    """
    upload_kwargs = {
        "folder_path": videos_dir,
        "path_in_repo": "videos",
        "repo_id": repo_id,
        "revision": revision,
        "repo_type": "dataset",
        "allow_patterns": "*.mp4",
    }
    HfApi().upload_folder(**upload_kwargs)
def push_dataset_to_hub(
raw_dir: Path,
raw_format: str,
repo_id: str,
push_to_hub: bool = True,
local_dir: Path | None = None,
fps: int | None = None,
video: bool = True,
batch_size: int = 32,
num_workers: int = 8,
episodes: list[int] | None = None,
force_override: bool = False,
cache_dir: Path = Path("/tmp"),
tests_data_dir: Path | None = None,
):
# Check repo_id is well formated
if len(repo_id.split("/")) != 2:
raise ValueError(
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
)
user_id, dataset_id = repo_id.split("/")
# Robustify when `raw_dir` is str instead of Path
raw_dir = Path(raw_dir)
if not raw_dir.exists():
raise NotADirectoryError(
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
)
if local_dir:
# Robustify when `local_dir` is str instead of Path
local_dir = Path(local_dir)
# Send warning if local_dir isn't well formated
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
warnings.warn(
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
stacklevel=1,
)
# Check we don't override an existing `local_dir` by mistake
if local_dir.exists():
if force_override:
shutil.rmtree(local_dir)
else:
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
meta_data_dir = local_dir / "meta_data"
videos_dir = local_dir / "videos"
else:
# Temporary directory used to store images, videos, meta_data
meta_data_dir = Path(cache_dir) / "meta_data"
videos_dir = Path(cache_dir) / "videos"
if raw_format is None:
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
raise NotImplementedError()
# raw_format = auto_find_raw_format(raw_dir)
# convert dataset from original raw format to LeRobot format
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
raw_dir, videos_dir, fps, video, episodes
)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
if local_dir:
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
if push_to_hub or local_dir:
# mandatory for upload
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if push_to_hub:
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
api = HfApi()
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
if tests_data_dir:
# get the first episode
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}
test_hf_dataset = test_hf_dataset.with_format(None)
test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
tests_meta_data = tests_data_dir / repo_id / "meta_data"
save_meta_data(info, stats, episode_data_index, tests_meta_data)
# copy videos of first episode to tests directory
episode_index = 0
tests_videos_dir = tests_data_dir / repo_id / "videos"
tests_videos_dir.mkdir(parents=True, exist_ok=True)
for key in lerobot_dataset.video_frame_keys:
fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
if local_dir is None:
# clear cache
shutil.rmtree(meta_data_dir)
shutil.rmtree(videos_dir)
return lerobot_dataset
def main():
    """Parse the command line arguments and forward them to `push_dataset_to_hub`.

    Every option name maps onto the keyword argument of the same name (dashes replaced by
    underscores). Boolean-like options (`--push-to-hub`, `--video`, `--force-override`) are
    parsed as integers, so they are switched with `0` / `1`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
    )
    # TODO(rcadene): add automatic detection of the format
    parser.add_argument(
        "--raw-format",
        type=str,
        required=True,
        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `cam_png`).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    parser.add_argument(
        "--local-dir",
        type=Path,
        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
    )
    parser.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
        help="Upload to hub.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        help="Frame rate used to collect videos. If not provided, use the default one specified in the code.",
    )
    parser.add_argument(
        "--video",
        type=int,
        default=1,
        help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 times faster loading time during training.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size loaded by DataLoader for computing the dataset statistics.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Number of processes of Dataloader for computing the dataset statistics.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        help="When provided, only converts the provided episodes (e.g. `--episodes 2 3 4`). Useful to test the code on 1 episode.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
    )
    parser.add_argument(
        "--tests-data-dir",
        type=Path,
        help=(
            "When provided, save tests artifacts into the given directory "
            "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
        ),
    )

    args = parser.parse_args()
    # The parsed attribute names match the keyword arguments of `push_dataset_to_hub`.
    push_dataset_to_hub(**vars(args))
# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()