Add available list of raw repo ids (#312)

This commit is contained in:
Remi 2024-07-13 11:30:50 +02:00 committed by GitHub
parent 471eab3d7e
commit 5ffcb48a9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 46 deletions

View File

@ -31,38 +31,7 @@ from pathlib import Path
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
AVAILABLE_RAW_REPO_IDS = [
def download_raw(raw_dir: Path, repo_id: str):
# Check repo_id is well formated
if len(repo_id.split("/")) != 2:
raise ValueError(
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
)
user_id, dataset_id = repo_id.split("/")
if not dataset_id.endswith("_raw"):
warnings.warn(
f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
stacklevel=1,
)
raw_dir = Path(raw_dir)
# Send warning if raw_dir isn't well formated
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
warnings.warn(
f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
stacklevel=1,
)
raw_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
def download_all_raw_datasets():
data_dir = Path("data")
repo_ids = [
"cadene/pusht_image_raw", "cadene/pusht_image_raw",
"cadene/xarm_lift_medium_image_raw", "cadene/xarm_lift_medium_image_raw",
"cadene/xarm_lift_medium_replay_image_raw", "cadene/xarm_lift_medium_replay_image_raw",
@ -103,14 +72,48 @@ def download_all_raw_datasets():
"cadene/aloha_static_vinh_cup_left_raw", "cadene/aloha_static_vinh_cup_left_raw",
"cadene/aloha_static_ziploc_slide_raw", "cadene/aloha_static_ziploc_slide_raw",
"cadene/umi_cup_in_the_wild_raw", "cadene/umi_cup_in_the_wild_raw",
] ]
for repo_id in repo_ids:
def download_raw(raw_dir: Path, repo_id: str):
# Check repo_id is well formated
if len(repo_id.split("/")) != 2:
raise ValueError(
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
)
user_id, dataset_id = repo_id.split("/")
if not dataset_id.endswith("_raw"):
warnings.warn(
f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
stacklevel=1,
)
raw_dir = Path(raw_dir)
# Send warning if raw_dir isn't well formated
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
warnings.warn(
f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
stacklevel=1,
)
raw_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
def download_all_raw_datasets():
data_dir = Path("data")
for repo_id in AVAILABLE_RAW_REPO_IDS:
raw_dir = data_dir / repo_id raw_dir = data_dir / repo_id
download_raw(raw_dir, repo_id) download_raw(raw_dir, repo_id)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(
description=f"A script to download raw datasets from Hugging Face hub to a local directory. Here is a non exhaustive list of available repositories to use in `--repo-id`: {AVAILABLE_RAW_REPO_IDS}",
)
parser.add_argument( parser.add_argument(
"--raw-dir", "--raw-dir",

View File

@ -208,8 +208,8 @@ def push_dataset_to_hub(
raw_dir = Path(raw_dir) raw_dir = Path(raw_dir)
if not raw_dir.exists(): if not raw_dir.exists():
raise NotADirectoryError( raise NotADirectoryError(
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:" f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw" f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
) )
if local_dir: if local_dir: