Add dataset cards (#363)
This commit is contained in:
parent
bbe9057225
commit
b98ea415c1
|
@ -788,13 +788,16 @@ python lerobot/scripts/control_robot.py record \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
--root data \
|
||||||
--repo-id ${HF_USER}/koch_test \
|
--repo-id ${HF_USER}/koch_test \
|
||||||
|
--tags tutorial \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
--episode-time-s 30 \
|
--episode-time-s 30 \
|
||||||
--reset-time-s 30 \
|
--reset-time-s 30 \
|
||||||
--num-episodes 2
|
--num-episodes 2
|
||||||
```
|
```
|
||||||
|
|
||||||
This will write your dataset to `{root}/{repo-id}` (e.g. `data/cadene/koch_test`).
|
This will write your dataset locally to `{root}/{repo-id}` (e.g. `data/cadene/koch_test`) and push it to the hub at `https://huggingface.co/datasets/{repo-id}` (e.g. `https://huggingface.co/datasets/cadene/koch_test`). Your dataset will be automatically tagged with `LeRobot` so that the community can find it easily, and you can also add custom tags (`tutorial` in this example).
|
||||||
|
|
||||||
|
You can look for other LeRobot datasets on the hub by searching for the `LeRobot` tag: https://huggingface.co/datasets?other=LeRobot
|
||||||
|
|
||||||
Remember to add `--robot-overrides '~cameras'` if you don't have any cameras and you still use the default `koch.yaml` configuration.
|
Remember to add `--robot-overrides '~cameras'` if you don't have any cameras and you still use the default `koch.yaml` configuration.
|
||||||
|
|
||||||
|
@ -998,6 +1001,7 @@ python lerobot/scripts/control_robot.py record \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
--root data \
|
||||||
--repo-id ${HF_USER}/eval_koch_test \
|
--repo-id ${HF_USER}/eval_koch_test \
|
||||||
|
--tags tutorial eval \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
--episode-time-s 30 \
|
--episode-time-s 30 \
|
||||||
--reset-time-s 30 \
|
--reset-time-s 30 \
|
||||||
|
|
|
@ -23,11 +23,19 @@ from typing import Dict
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset, load_from_disk
|
from datasets import load_dataset, load_from_disk
|
||||||
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
|
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
|
DATASET_CARD_TEMPLATE = """
|
||||||
|
---
|
||||||
|
# Metadata will go there
|
||||||
|
---
|
||||||
|
This dataset was created using [🤗 LeRobot](https://github.com/huggingface/lerobot).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def flatten_dict(d, parent_key="", sep="/"):
|
def flatten_dict(d, parent_key="", sep="/"):
|
||||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||||
|
@ -400,3 +408,14 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
|
||||||
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
|
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||||
|
|
||||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||||
|
|
||||||
|
|
||||||
|
def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
    """Build a dataset card from `DATASET_CARD_TEMPLATE`.

    The card metadata always gets the "robotics" task category and the
    "LeRobot" tag.

    Args:
        tags: Optional extra tags, appended after the "LeRobot" tag.
        text: Optional text, appended to the text of the template card.

    Returns:
        The populated `DatasetCard` (nothing is pushed here).
    """
    # Assemble the full tag list first; the caller's `tags` object is never mutated.
    all_tags = ["LeRobot"]
    if tags is not None:
        all_tags.extend(tags)

    card = DatasetCard(DATASET_CARD_TEMPLATE)
    card.data.task_categories = ["robotics"]
    card.data.tags = all_tags

    if text is not None:
        card.text = card.text + text

    return card
|
||||||
|
|
|
@ -130,7 +130,12 @@ from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||||
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
|
from lerobot.scripts.push_dataset_to_hub import (
|
||||||
|
push_dataset_card_to_hub,
|
||||||
|
push_meta_data_to_hub,
|
||||||
|
push_videos_to_hub,
|
||||||
|
save_meta_data,
|
||||||
|
)
|
||||||
|
|
||||||
########################################################################################
|
########################################################################################
|
||||||
# Utilities
|
# Utilities
|
||||||
|
@ -292,6 +297,7 @@ def record(
|
||||||
video=True,
|
video=True,
|
||||||
run_compute_stats=True,
|
run_compute_stats=True,
|
||||||
push_to_hub=True,
|
push_to_hub=True,
|
||||||
|
tags=None,
|
||||||
num_image_writers=8,
|
num_image_writers=8,
|
||||||
force_override=False,
|
force_override=False,
|
||||||
):
|
):
|
||||||
|
@ -647,6 +653,7 @@ def record(
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||||
|
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
|
||||||
if video:
|
if video:
|
||||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||||
|
@ -758,6 +765,12 @@ if __name__ == "__main__":
|
||||||
default=1,
|
default=1,
|
||||||
help="Upload dataset to Hugging Face hub.",
|
help="Upload dataset to Hugging Face hub.",
|
||||||
)
|
)
|
||||||
|
parser_record.add_argument(
|
||||||
|
"--tags",
|
||||||
|
type=str,
|
||||||
|
nargs="*",
|
||||||
|
help="Add tags to your dataset on the hub.",
|
||||||
|
)
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--num-image-writers",
|
"--num-image-writers",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
|
@ -56,7 +56,7 @@ from safetensors.torch import save_file
|
||||||
from lerobot.common.datasets.compute_stats import compute_stats
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||||
from lerobot.common.datasets.utils import create_branch, flatten_dict
|
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
|
||||||
|
|
||||||
|
|
||||||
def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||||
|
@ -114,6 +114,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def push_dataset_card_to_hub(
    repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
):
    """Create a LeRobot dataset card and upload it to the dataset repo on the hub.

    Args:
        repo_id: Identifier of the hub repository that receives the card.
        revision: Revision the card is pushed to (forwarded as-is to `push_to_hub`).
        tags: Optional extra tags, forwarded to `create_lerobot_dataset_card`.
        text: Optional extra card text, forwarded to `create_lerobot_dataset_card`.
    """
    create_lerobot_dataset_card(tags=tags, text=text).push_to_hub(
        repo_id=repo_id,
        repo_type="dataset",
        revision=revision,
    )
|
||||||
|
|
||||||
|
|
||||||
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
|
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
|
||||||
"""Expect mp4 files to be all stored in a single "videos" directory.
|
"""Expect mp4 files to be all stored in a single "videos" directory.
|
||||||
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
|
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
|
||||||
|
@ -213,6 +221,7 @@ def push_dataset_to_hub(
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||||
|
push_dataset_card_to_hub(repo_id, revision="main")
|
||||||
if video:
|
if video:
|
||||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||||
|
|
Loading…
Reference in New Issue