# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import logging
import os
from pathlib import Path
from typing import Type, TypeVar

import packaging
import packaging.version
import safetensors
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn

from lerobot.common.utils.hub import HubMixin
from lerobot.configs.policies import PreTrainedConfig

# Type variable bound to PreTrainedPolicy. It annotates `from_pretrained` (`cls: Type[T]` -> `T`),
# so calling that classmethod on a subclass is typed as returning that subclass.
T = TypeVar("T", bound="PreTrainedPolicy")

# Model-card template text for policies pushed to the Hub.
# The `{{ ... }}` placeholders and the `| default(...)` filter look like Jinja template syntax,
# presumably rendered by huggingface_hub's `ModelCard.from_template` — confirm.
# NOTE(review): no live code in this module references this constant; it may be imported
# elsewhere, so verify before changing or removing it.
DEFAULT_POLICY_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---

This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""

class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
    """
    Base class for policy models.

    Subclasses must define the class attributes ``config_class`` and ``name`` (checked in
    ``__init_subclass__``). They must also implement ``get_optim_params``, ``reset``,
    ``forward`` and ``select_action``.

    Saving writes the config plus a single safetensors weights file (``SAFETENSORS_SINGLE_FILE``).
    ``from_pretrained`` reads them back from a local directory or from the Hugging Face Hub.
    """

    # Required on every subclass: `__init_subclass__` raises TypeError if either is missing or falsy.
    # These are annotations only; the base class itself assigns no value to them.
    config_class: type[PreTrainedConfig]
    name: str

    def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
        """Store the policy configuration on the instance.

        Args:
            config: Policy configuration; must be a `PreTrainedConfig` instance.
            *inputs: Unused by this base implementation.
            **kwargs: Unused by this base implementation.

        Raises:
            ValueError: If `config` is not a `PreTrainedConfig` instance.
        """
        super().__init__()
        if not isinstance(config, PreTrainedConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`PreTrainedConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.config = config

    def __init_subclass__(cls, **kwargs):
        """Check at class-creation time that a subclass sets `config_class` and `name`.

        Raises:
            TypeError: If either attribute is missing or falsy on the new subclass.
        """
        super().__init_subclass__(**kwargs)
        if not getattr(cls, "config_class", None):
            raise TypeError(f"Class {cls.__name__} must define 'config_class'")
        if not getattr(cls, "name", None):
            raise TypeError(f"Class {cls.__name__} must define 'name'")

    def _save_pretrained(self, save_directory: Path) -> None:
        """Write the config and the model weights into `save_directory`.

        Weights are written in safetensors format to `save_directory / SAFETENSORS_SINGLE_FILE`.
        """
        self.config._save_pretrained(save_directory)
        # NOTE(review): this unwraps any object exposing `.module` (a wrapper-style convention).
        # On a policy it also triggers if a subclass registers a submodule named `module`.
        # In that case only that submodule would be saved — confirm this is intended.
        model_to_save = self.module if hasattr(self, "module") else self
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool = False,
        **kwargs,
    ) -> T:
        """
        Instantiate a policy and load its weights from a local directory or a Hub repo.

        If `config` is None, it is loaded with `PreTrainedConfig.from_pretrained` using the same
        download arguments. Extra `kwargs` are forwarded both to that call and to the class
        constructor. If `pretrained_name_or_path` is an existing directory, the weights are read
        from `SAFETENSORS_SINGLE_FILE` inside it. Otherwise it is treated as a Hub repo id and the
        file is fetched with `hf_hub_download`.

        The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are
        deactivated). To train it, you should first set it back in training mode with `policy.train()`.

        Args:
            pretrained_name_or_path: Local directory or Hub repo id.
            config: Optional pre-built config; skips loading it from `pretrained_name_or_path`.
            force_download, resume_download, proxies, token, cache_dir, local_files_only, revision:
                Forwarded to the config loader and to `hf_hub_download`.
            strict: Forwarded to the safetensors loader. It controls whether missing or unexpected
                keys raise.

        Returns:
            The policy with weights loaded, moved to `config.device`, in eval mode.

        Raises:
            FileNotFoundError: If the weights file is missing from the local directory, or if
                downloading it from the Hub fails with an HTTP error.
        """
        if config is None:
            config = PreTrainedConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )
        model_id = str(pretrained_name_or_path)
        instance = cls(config, **kwargs)

        if os.path.isdir(model_id):
            logging.info("Loading weights from local directory %s", model_id)
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            # Fail with the same exception type as the Hub branch instead of an opaque loader error.
            if not os.path.isfile(model_file):
                raise FileNotFoundError(f"{SAFETENSORS_SINGLE_FILE} not found in local directory {model_id}")
        else:
            # Only the download can raise HfHubHTTPError, so keep the try around it alone.
            # Any HTTP failure (not just 404) is reported as FileNotFoundError; the cause is chained.
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
                ) from e

        policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
        # Move the whole module as well: with strict=False, parameters/buffers absent from the file
        # are not touched by the loader and would otherwise stay on their original device.
        policy.to(config.device)
        policy.eval()
        return policy

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """Load weights from `model_file` into `model` in place and return `model`.

        With safetensors >= 0.4.3 the tensors are loaded directly onto `map_location`. Older
        versions have no `device` argument, so the model is loaded on CPU and then moved.
        """
        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
            load_model_as_safetensor(model, model_file, strict=strict)
            if map_location != "cpu":
                logging.warning(
                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                    " This means that the model is loaded on 'cpu' first and then copied to the device."
                    " This leads to a slower loading time."
                    " Please update safetensors to version 0.4.3 or above for improved performance."
                )
                model.to(map_location)
        else:
            # Same function as `safetensors.torch.load_model`, via the module-level alias.
            load_model_as_safetensor(model, model_file, strict=strict, device=map_location)
        return model

    # NOTE: disabled; `ModelCard` is not imported in this module.
    # def generate_model_card(self, *args, **kwargs) -> ModelCard:
    #     card = ModelCard.from_template(
    #         card_data=self._hub_mixin_info.model_card_data,
    #         template_str=self._hub_mixin_info.model_card_template,
    #         repo_url=self._hub_mixin_info.repo_url,
    #         docs_url=self._hub_mixin_info.docs_url,
    #         **kwargs,
    #     )
    #     return card

    @abc.abstractmethod
    def get_optim_params(self) -> dict:
        """
        Returns the policy-specific parameters dict to be passed on to the optimizer.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def reset(self):
        """To be called whenever the environment is reset.

        Does things like clearing caches.
        """
        raise NotImplementedError

    # TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
    @abc.abstractmethod
    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
        """Run the policy on a batch and compute its loss.

        Args:
            batch (dict[str, Tensor]): Named input tensors for one batch. The exact keys presumably
                depend on the concrete policy and its config.

        Returns:
            tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
                is a Tensor, all other items should be logging-friendly, native Python types.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Return one action to run in the environment (potentially in batch mode).

        When the model uses a history of observations, or outputs a sequence of actions, this method deals
        with caching.
        """
        raise NotImplementedError