Add upload folders

parent 17a1214e25
commit 1016a983a1
@@ -34,6 +34,7 @@ TODO
 """
 
 import argparse
+import contextlib
 import json
 import math
 import subprocess
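The new contextlib import supports the contextlib.suppress blocks introduced further down, which replace an explicit try/except-pass around Hub calls. A minimal sketch of the pattern, with FileNotFoundError standing in for the Hub error used later:

    import contextlib
    import os

    # Equivalent to: try: os.remove(...) except FileNotFoundError: pass
    with contextlib.suppress(FileNotFoundError):
        os.remove("missing.txt")  # no error if the file is already gone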
@@ -45,10 +46,13 @@ import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import torch
 from huggingface_hub import HfApi
+from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
 from safetensors.torch import load_file
 
+from lerobot.common.datasets.utils import create_branch
 from lerobot.common.utils.utils import init_hydra_config
+from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
 
 V1_6 = "v1.6"
 V2_0 = "v2.0"
@@ -374,9 +378,11 @@ def convert_dataset(
     if robot_config is not None:
         robot_type = robot_config["robot_type"]
         names = robot_config["names"]
+        repo_tags = [robot_type]
     else:
         robot_type = "unknown"
         names = get_generic_motor_names(sequence_shapes)
+        repo_tags = None
 
     assert set(names) == set(keys["sequence"])
     for key in sequence_shapes:
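When a robot config is available, its robot_type doubles as the single Hub tag pushed with the dataset card; otherwise generic motor names are derived from the sequence shapes and repo_tags stays None, so the card is presumably pushed untagged. A sketch of the two branches with a hypothetical config (the "koch" value and motor names are illustrative only):

    robot_config = {"robot_type": "koch", "names": ["shoulder_pan", "shoulder_lift"]}

    if robot_config is not None:
        robot_type = robot_config["robot_type"]  # "koch"
        repo_tags = [robot_type]                 # ["koch"]
    else:
        robot_type = "unknown"
        repo_tags = None                         # card pushed without tags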
@@ -396,6 +402,7 @@ def convert_dataset(
         "total_episodes": total_episodes,
         "total_tasks": len(tasks),
         "fps": metadata_v1_6["fps"],
+        "splits": {"train": f"0:{total_episodes}"},
         "image_keys": keys["video"] + keys["image"],
         "keys": keys["sequence"],
         "shapes": {**image_shapes, **video_shapes, **sequence_shapes},
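The new "splits" entry records the train split as a half-open episode index range in "start:end" form. A short sketch of how such a string can be decoded back into episode indices (parse_split is a hypothetical helper, not part of the script):

    def parse_split(split: str) -> range:
        # "0:50" -> range(0, 50), i.e. episodes 0..49
        start, end = map(int, split.split(":"))
        return range(start, end)

    assert list(parse_split("0:3")) == [0, 1, 2]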
@@ -404,15 +411,32 @@ def convert_dataset(
         "episodes": episodes,
     }
     write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
 
     convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")
 
-    # test_repo_id = f"aliberts/{repo_id.split('/')[1]}"
-    # if hub_api.repo_exists(test_repo_id, repo_type="dataset"):
-    #     hub_api.delete_repo(test_repo_id, repo_type="dataset")
+    with contextlib.suppress(EntryNotFoundError):
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")
 
-    # hub_api.create_repo(test_repo_id, repo_type="dataset", exist_ok=True)
-    # hub_api.upload_folder(repo_id=test_repo_id, folder_path=v2_0_dir, repo_type="dataset")
+    with contextlib.suppress(EntryNotFoundError):
+        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision="main")
+    hub_api.upload_folder(
+        repo_id=repo_id,
+        path_in_repo="data",
+        folder_path=v2_0_dir / "data",
+        repo_type="dataset",
+        revision="main",
+    )
+    hub_api.upload_folder(
+        repo_id=repo_id,
+        path_in_repo="meta",
+        folder_path=v2_0_dir / "meta",
+        repo_type="dataset",
+        revision="main",
+    )
+    metadata_v2_0.pop("episodes")
+    card_text = f"```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
+    push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
+    create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset")
 
 # TODO:
 # - [X] Add shapes
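The upload sequence is written to be idempotent: any stale v1.6 "data" and "meta_data" folders are first deleted from main, with EntryNotFoundError suppressed when they are already absent, then the converted "data" and "meta" folders are uploaded, the dataset card is regenerated from the v2.0 metadata (minus the bulky episodes list), and a v2.0 branch pins the result. The delete-then-upload pattern in isolation, against an illustrative repo id and local path:

    import contextlib

    from huggingface_hub import HfApi
    from huggingface_hub.errors import EntryNotFoundError

    hub_api = HfApi()
    repo_id = "user/scratch-dataset"  # illustrative

    # delete_folder raises EntryNotFoundError if the folder is missing;
    # suppressing it makes repeated runs safe.
    with contextlib.suppress(EntryNotFoundError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision="main")

    hub_api.upload_folder(
        repo_id=repo_id,
        path_in_repo="data",
        folder_path="local/v2.0/data",  # illustrative local path
        repo_type="dataset",
        revision="main",
    )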
@@ -422,9 +446,10 @@ def convert_dataset(
 # - [X] Add task.json
 # - [X] Add names
 # - [X] Add robot_type
+# - [X] Add splits
+# - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
 # - [/] Add sanity checks (encoding, shapes)
 # - [ ] Handle multitask datasets
-# - [ ] Push properly to branch v2.0 and delete v1.6 stuff from that branch
 
 
 def main():