203 lines
8.8 KiB
Python
203 lines
8.8 KiB
Python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
from typing import Any, Type, TypeVar
|
|
|
|
from huggingface_hub import HfApi
|
|
from huggingface_hub.utils import validate_hf_hub_args
|
|
|
|
T = TypeVar("T", bound="HubMixin")
|
|
|
|
|
|
class HubMixin:
|
|
"""
|
|
A Mixin containing the functionality to push an object to the hub.
|
|
|
|
This is similar to huggingface_hub.ModelHubMixin but is lighter and makes less assumptions about its
|
|
subclasses (in particular, the fact that it's not necessarily a model).
|
|
|
|
The inheriting classes must implement '_save_pretrained' and 'from_pretrained'.
|
|
"""
|
|
|
|
def save_pretrained(
|
|
self,
|
|
save_directory: str | Path,
|
|
*,
|
|
repo_id: str | None = None,
|
|
push_to_hub: bool = False,
|
|
card_kwargs: dict[str, Any] | None = None,
|
|
**push_to_hub_kwargs,
|
|
) -> str | None:
|
|
"""
|
|
Save object in local directory.
|
|
|
|
Args:
|
|
save_directory (`str` or `Path`):
|
|
Path to directory in which the object will be saved.
|
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
|
Whether or not to push your object to the Huggingface Hub after saving it.
|
|
repo_id (`str`, *optional*):
|
|
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
|
|
not provided.
|
|
card_kwargs (`Dict[str, Any]`, *optional*):
|
|
Additional arguments passed to the card template to customize the card.
|
|
push_to_hub_kwargs:
|
|
Additional key word arguments passed along to the [`~HubMixin.push_to_hub`] method.
|
|
Returns:
|
|
`str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
|
|
"""
|
|
save_directory = Path(save_directory)
|
|
save_directory.mkdir(parents=True, exist_ok=True)
|
|
|
|
# save object (weights, files, etc.)
|
|
self._save_pretrained(save_directory)
|
|
|
|
# push to the Hub if required
|
|
if push_to_hub:
|
|
if repo_id is None:
|
|
repo_id = save_directory.name # Defaults to `save_directory` name
|
|
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
|
|
return None
|
|
|
|
def _save_pretrained(self, save_directory: Path) -> None:
|
|
"""
|
|
Overwrite this method in subclass to define how to save your object.
|
|
|
|
Args:
|
|
save_directory (`str` or `Path`):
|
|
Path to directory in which the object files will be saved.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
@classmethod
|
|
@validate_hf_hub_args
|
|
def from_pretrained(
|
|
cls: Type[T],
|
|
pretrained_name_or_path: str | Path,
|
|
*,
|
|
force_download: bool = False,
|
|
resume_download: bool | None = None,
|
|
proxies: dict | None = None,
|
|
token: str | bool | None = None,
|
|
cache_dir: str | Path | None = None,
|
|
local_files_only: bool = False,
|
|
revision: str | None = None,
|
|
**kwargs,
|
|
) -> T:
|
|
"""
|
|
Download the object from the Huggingface Hub and instantiate it.
|
|
|
|
Args:
|
|
pretrained_name_or_path (`str`, `Path`):
|
|
- Either the `repo_id` (string) of the object hosted on the Hub, e.g. `lerobot/diffusion_pusht`.
|
|
- Or a path to a `directory` containing the object files saved using `.save_pretrained`,
|
|
e.g., `../path/to/my_model_directory/`.
|
|
revision (`str`, *optional*):
|
|
Revision on the Hub. Can be a branch name, a git tag or any commit id.
|
|
Defaults to the latest commit on `main` branch.
|
|
force_download (`bool`, *optional*, defaults to `False`):
|
|
Whether to force (re-)downloading the files from the Hub, overriding the existing cache.
|
|
proxies (`Dict[str, str]`, *optional*):
|
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
|
|
token (`str` or `bool`, *optional*):
|
|
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
|
cached when running `huggingface-cli login`.
|
|
cache_dir (`str`, `Path`, *optional*):
|
|
Path to the folder where cached files are stored.
|
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
|
If `True`, avoid downloading the file and return the path to the local cached file if it exists.
|
|
kwargs (`Dict`, *optional*):
|
|
Additional kwargs to pass to the object during initialization.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
@validate_hf_hub_args
|
|
def push_to_hub(
|
|
self,
|
|
repo_id: str,
|
|
*,
|
|
commit_message: str | None = None,
|
|
private: bool | None = None,
|
|
token: str | None = None,
|
|
branch: str | None = None,
|
|
create_pr: bool | None = None,
|
|
allow_patterns: list[str] | str | None = None,
|
|
ignore_patterns: list[str] | str | None = None,
|
|
delete_patterns: list[str] | str | None = None,
|
|
card_kwargs: dict[str, Any] | None = None,
|
|
) -> str:
|
|
"""
|
|
Upload model checkpoint to the Hub.
|
|
|
|
Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
|
|
`delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
|
|
details.
|
|
|
|
Args:
|
|
repo_id (`str`):
|
|
ID of the repository to push to (example: `"username/my-model"`).
|
|
commit_message (`str`, *optional*):
|
|
Message to commit while pushing.
|
|
private (`bool`, *optional*):
|
|
Whether the repository created should be private.
|
|
If `None` (default), the repo will be public unless the organization's default is private.
|
|
token (`str`, *optional*):
|
|
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
|
cached when running `huggingface-cli login`.
|
|
branch (`str`, *optional*):
|
|
The git branch on which to push the model. This defaults to `"main"`.
|
|
create_pr (`boolean`, *optional*):
|
|
Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
|
|
allow_patterns (`List[str]` or `str`, *optional*):
|
|
If provided, only files matching at least one pattern are pushed.
|
|
ignore_patterns (`List[str]` or `str`, *optional*):
|
|
If provided, files matching any of the patterns are not pushed.
|
|
delete_patterns (`List[str]` or `str`, *optional*):
|
|
If provided, remote files matching any of the patterns will be deleted from the repo.
|
|
card_kwargs (`Dict[str, Any]`, *optional*):
|
|
Additional arguments passed to the card template to customize the card.
|
|
|
|
Returns:
|
|
The url of the commit of your object in the given repository.
|
|
"""
|
|
api = HfApi(token=token)
|
|
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
|
|
|
|
if commit_message is None:
|
|
if "Policy" in self.__class__.__name__:
|
|
commit_message = "Upload policy"
|
|
elif "Config" in self.__class__.__name__:
|
|
commit_message = "Upload config"
|
|
else:
|
|
commit_message = f"Upload {self.__class__.__name__}"
|
|
|
|
# Push the files to the repo in a single commit
|
|
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
|
saved_path = Path(tmp) / repo_id
|
|
self.save_pretrained(saved_path, card_kwargs=card_kwargs)
|
|
return api.upload_folder(
|
|
repo_id=repo_id,
|
|
repo_type="model",
|
|
folder_path=saved_path,
|
|
commit_message=commit_message,
|
|
revision=branch,
|
|
create_pr=create_pr,
|
|
allow_patterns=allow_patterns,
|
|
ignore_patterns=ignore_patterns,
|
|
delete_patterns=delete_patterns,
|
|
)
|