Fix dataset version tags (#790)

This commit is contained in:
Simon Alibert 2025-02-28 14:36:20 +01:00 committed by GitHub
parent ed83cbd4f2
commit e81c36cf74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 1 deletions

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import logging import logging
import shutil import shutil
from pathlib import Path from pathlib import Path
@@ -27,6 +28,7 @@ import torch.utils
from datasets import concatenate_datasets, load_dataset from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi, snapshot_download from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
@@ -517,6 +519,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
branch: str | None = None, branch: str | None = None,
tags: list | None = None, tags: list | None = None,
license: str | None = "apache-2.0", license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True, push_videos: bool = True,
private: bool = False, private: bool = False,
allow_patterns: list[str] | str | None = None, allow_patterns: list[str] | str | None = None,
@@ -562,6 +565,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
) )
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
def pull_from_repo( def pull_from_repo(
self, self,
allow_patterns: list[str] | str | None = None, allow_patterns: list[str] | str | None = None,

View File

@@ -31,6 +31,7 @@ import packaging.version
import torch import torch
from datasets.table import embed_table_storage from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage from PIL import Image as PILImage
from torchvision import transforms from torchvision import transforms
@@ -325,6 +326,19 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
) )
hub_versions = get_repo_versions(repo_id) hub_versions = get_repo_versions(repo_id)
if not hub_versions:
raise RevisionNotFoundError(
f"""Your dataset must be tagged with a codebase version.
Assuming _version_ is the codebase_version value in the info.json, you can run this:
```python
from huggingface_hub import HfApi
hub_api = HfApi()
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
```
"""
)
if target_version in hub_versions: if target_version in hub_versions:
return f"v{target_version}" return f"v{target_version}"

View File

@@ -57,7 +57,7 @@ def convert_dataset(
dataset.meta.info["codebase_version"] = CODEBASE_VERSION dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root) write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, allow_patterns="meta/") dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file # delete old stats.json file
if (dataset.root / STATS_PATH).is_file: if (dataset.root / STATS_PATH).is_file: