diff --git a/convert_dataset_16_to_20.py b/convert_dataset_16_to_20.py
index fdb5f233..c53f7595 100644
--- a/convert_dataset_16_to_20.py
+++ b/convert_dataset_16_to_20.py
@@ -34,6 +34,7 @@ TODO
 """
 
 import argparse
+import contextlib
 import json
 import math
 import subprocess
@@ -45,10 +46,13 @@ import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import torch
 from huggingface_hub import HfApi
+from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
 from safetensors.torch import load_file
 
+from lerobot.common.datasets.utils import create_branch
 from lerobot.common.utils.utils import init_hydra_config
+from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
 
 V1_6 = "v1.6"
 V2_0 = "v2.0"
@@ -374,9 +378,11 @@ def convert_dataset(
     if robot_config is not None:
         robot_type = robot_config["robot_type"]
         names = robot_config["names"]
+        repo_tags = [robot_type]
     else:
         robot_type = "unknown"
         names = get_generic_motor_names(sequence_shapes)
+        repo_tags = None
 
     assert set(names) == set(keys["sequence"])
     for key in sequence_shapes:
@@ -396,6 +402,7 @@ def convert_dataset(
         "total_episodes": total_episodes,
         "total_tasks": len(tasks),
         "fps": metadata_v1_6["fps"],
+        "splits": {"train": f"0:{total_episodes}"},
         "image_keys": keys["video"] + keys["image"],
         "keys": keys["sequence"],
         "shapes": {**image_shapes, **video_shapes, **sequence_shapes},
@@ -404,15 +411,32 @@ def convert_dataset(
         "episodes": episodes,
     }
     write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
-    convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")
 
-    # test_repo_id = f"aliberts/{repo_id.split('/')[1]}"
-    # if hub_api.repo_exists(test_repo_id, repo_type="dataset"):
-    #     hub_api.delete_repo(test_repo_id, repo_type="dataset")
+    with contextlib.suppress(EntryNotFoundError):
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")
 
 
-    # hub_api.create_repo(test_repo_id, repo_type="dataset", exist_ok=True)
-    # hub_api.upload_folder(repo_id=test_repo_id, folder_path=v2_0_dir, repo_type="dataset")
+    with contextlib.suppress(EntryNotFoundError):
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision="main")
+
+    hub_api.upload_folder(
+        repo_id=repo_id,
+        path_in_repo="data",
+        folder_path=v2_0_dir / "data",
+        repo_type="dataset",
+        revision="main",
+    )
+    hub_api.upload_folder(
+        repo_id=repo_id,
+        path_in_repo="meta",
+        folder_path=v2_0_dir / "meta",
+        repo_type="dataset",
+        revision="main",
+    )
+    metadata_v2_0.pop("episodes")
+    card_text = f"```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
+    push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
+    create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset")
 
     # TODO:
     # - [X] Add shapes
@@ -422,9 +446,10 @@ def convert_dataset(
     # - [X] Add task.json
     # - [X] Add names
     # - [X] Add robot_type
+    # - [X] Add splits
+    # - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
     # - [/] Add sanity checks (encoding, shapes)
     # - [ ] Handle multitask datasets
-    # - [ ] Push properly to branch v2.0 and delete v1.6 stuff from that branch
 
 
 def main():
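
Note on verifying the push: after this patch runs, the converted `meta/info.json` lives on `main` and the `v2.0` branch created by `create_branch` points at that snapshot. Below is a minimal consumer-side sketch, not part of the patch; the `repo_id` value is a hypothetical example, and it only assumes the conversion above has been pushed and the `v2.0` revision exists.

```python
import json

from huggingface_hub import hf_hub_download

# Hypothetical dataset repo converted by the script above.
repo_id = "lerobot/aloha_sim_insertion_human"

# Fetch the info.json written by convert_dataset(), pinned to the v2.0 branch.
info_path = hf_hub_download(
    repo_id=repo_id,
    filename="meta/info.json",
    repo_type="dataset",
    revision="v2.0",
)
with open(info_path) as f:
    info = json.load(f)

# The conversion adds "splits" as episode ranges, e.g. {"train": "0:50"}.
assert "splits" in info and "train" in info["splits"]
print(info["splits"])
```

Pinning `revision="v2.0"` rather than `main` matters here: `main` is mutated in place (the old `data/` and `meta_data/` folders are deleted before upload, with `EntryNotFoundError` suppressed so the deletes are idempotent), while the branch gives downstream code a stable reference.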