Add upload folders

Simon Alibert 2024-10-04 14:26:50 +02:00
parent 17a1214e25
commit 1016a983a1
1 changed file with 32 additions and 7 deletions

@@ -34,6 +34,7 @@ TODO
"""
import argparse
import contextlib
import json
import math
import subprocess
@@ -45,10 +46,13 @@ import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError
from PIL import Image
from safetensors.torch import load_file

from lerobot.common.datasets.utils import create_branch
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub

V1_6 = "v1.6"
V2_0 = "v2.0"
@@ -374,9 +378,11 @@ def convert_dataset(
    if robot_config is not None:
        robot_type = robot_config["robot_type"]
        names = robot_config["names"]
        repo_tags = [robot_type]
    else:
        robot_type = "unknown"
        names = get_generic_motor_names(sequence_shapes)
        repo_tags = None

    assert set(names) == set(keys["sequence"])
    for key in sequence_shapes:
@@ -396,6 +402,7 @@ def convert_dataset(
"total_episodes": total_episodes,
"total_tasks": len(tasks),
"fps": metadata_v1_6["fps"],
"splits": {"train": f"0:{total_episodes}"},
"image_keys": keys["video"] + keys["image"],
"keys": keys["sequence"],
"shapes": {**image_shapes, **video_shapes, **sequence_shapes},
@@ -404,15 +411,32 @@ def convert_dataset(
"episodes": episodes,
}
write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")
    # test_repo_id = f"aliberts/{repo_id.split('/')[1]}"
    # if hub_api.repo_exists(test_repo_id, repo_type="dataset"):
    # hub_api.delete_repo(test_repo_id, repo_type="dataset")
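    # Drop the v1.6 "data" and "meta_data" folders from the main branch before re-uploading;
    # EntryNotFoundError is suppressed so the deletion is a no-op if a folder is already gone.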
    with contextlib.suppress(EntryNotFoundError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")
    # hub_api.create_repo(test_repo_id, repo_type="dataset", exist_ok=True)
    # hub_api.upload_folder(repo_id=test_repo_id, folder_path=v2_0_dir, repo_type="dataset")
    with contextlib.suppress(EntryNotFoundError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision="main")
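    # Upload the converted dataset to main: parquet files under data/ and the v2.0 metadata under meta/.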
    hub_api.upload_folder(
        repo_id=repo_id,
        path_in_repo="data",
        folder_path=v2_0_dir / "data",
        repo_type="dataset",
        revision="main",
    )
    hub_api.upload_folder(
        repo_id=repo_id,
        path_in_repo="meta",
        folder_path=v2_0_dir / "meta",
        repo_type="dataset",
        revision="main",
    )
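    # Refresh the dataset card with the new info (minus the per-episode list) and snapshot the converted state on a v2.0 branch.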
    metadata_v2_0.pop("episodes")
    card_text = f"```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
    push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
    create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset")
# TODO:
# - [X] Add shapes
@@ -422,9 +446,10 @@ def convert_dataset(
# - [X] Add task.json
# - [X] Add names
# - [X] Add robot_type
# - [X] Add splits
# - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
# - [/] Add sanity checks (encoding, shapes)
# - [ ] Handle multitask datasets
# - [ ] Push properly to branch v2.0 and delete v1.6 stuff from that branch
def main():