177 lines
6.7 KiB
Python
177 lines
6.7 KiB
Python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import abc
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Type, TypeVar
|
|
|
|
import draccus
|
|
from huggingface_hub import hf_hub_download
|
|
from huggingface_hub.constants import CONFIG_NAME
|
|
from huggingface_hub.errors import HfHubHTTPError
|
|
|
|
from lerobot.common.optim.optimizers import OptimizerConfig
|
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
|
from lerobot.common.utils.hub import HubMixin
|
|
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
|
|
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
|
T = TypeVar("T", bound="PreTrainedConfig")
|
|
|
|
|
|
@dataclass
|
|
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|
"""
|
|
Base configuration class for policy models.
|
|
|
|
Args:
|
|
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
|
current step and additional steps going back).
|
|
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
|
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
|
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
|
|
normalization mode to apply.
|
|
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
|
the original scale.
|
|
"""
|
|
|
|
n_obs_steps: int = 1
|
|
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
|
|
|
|
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
|
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
|
|
|
device: str | None = None # cuda | cpu | mp
|
|
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
|
# automatic gradient scaling is used.
|
|
use_amp: bool = False
|
|
|
|
def __post_init__(self):
|
|
self.pretrained_path = None
|
|
if not self.device or not is_torch_device_available(self.device):
|
|
auto_device = auto_select_torch_device()
|
|
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
|
self.device = auto_device.type
|
|
|
|
# Automatically deactivate AMP if necessary
|
|
if self.use_amp and not is_amp_available(self.device):
|
|
logging.warning(
|
|
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
|
)
|
|
self.use_amp = False
|
|
|
|
@property
|
|
def type(self) -> str:
|
|
return self.get_choice_name(self.__class__)
|
|
|
|
@abc.abstractproperty
|
|
def observation_delta_indices(self) -> list | None:
|
|
raise NotImplementedError
|
|
|
|
@abc.abstractproperty
|
|
def action_delta_indices(self) -> list | None:
|
|
raise NotImplementedError
|
|
|
|
@abc.abstractproperty
|
|
def reward_delta_indices(self) -> list | None:
|
|
raise NotImplementedError
|
|
|
|
@abc.abstractmethod
|
|
def get_optimizer_preset(self) -> OptimizerConfig:
|
|
raise NotImplementedError
|
|
|
|
@abc.abstractmethod
|
|
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
|
raise NotImplementedError
|
|
|
|
@abc.abstractmethod
|
|
def validate_features(self) -> None:
|
|
raise NotImplementedError
|
|
|
|
@property
|
|
def robot_state_feature(self) -> PolicyFeature | None:
|
|
for _, ft in self.input_features.items():
|
|
if ft.type is FeatureType.STATE:
|
|
return ft
|
|
return None
|
|
|
|
@property
|
|
def env_state_feature(self) -> PolicyFeature | None:
|
|
for _, ft in self.input_features.items():
|
|
if ft.type is FeatureType.ENV:
|
|
return ft
|
|
return None
|
|
|
|
@property
|
|
def image_features(self) -> dict[str, PolicyFeature]:
|
|
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
|
|
|
@property
|
|
def action_feature(self) -> PolicyFeature | None:
|
|
for _, ft in self.output_features.items():
|
|
if ft.type is FeatureType.ACTION:
|
|
return ft
|
|
return None
|
|
|
|
def _save_pretrained(self, save_directory: Path) -> None:
|
|
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
|
draccus.dump(self, f, indent=4)
|
|
|
|
@classmethod
|
|
def from_pretrained(
|
|
cls: Type[T],
|
|
pretrained_name_or_path: str | Path,
|
|
*,
|
|
force_download: bool = False,
|
|
resume_download: bool = None,
|
|
proxies: dict | None = None,
|
|
token: str | bool | None = None,
|
|
cache_dir: str | Path | None = None,
|
|
local_files_only: bool = False,
|
|
revision: str | None = None,
|
|
**policy_kwargs,
|
|
) -> T:
|
|
model_id = str(pretrained_name_or_path)
|
|
config_file: str | None = None
|
|
if Path(model_id).is_dir():
|
|
if CONFIG_NAME in os.listdir(model_id):
|
|
config_file = os.path.join(model_id, CONFIG_NAME)
|
|
else:
|
|
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
|
else:
|
|
try:
|
|
config_file = hf_hub_download(
|
|
repo_id=model_id,
|
|
filename=CONFIG_NAME,
|
|
revision=revision,
|
|
cache_dir=cache_dir,
|
|
force_download=force_download,
|
|
proxies=proxies,
|
|
resume_download=resume_download,
|
|
token=token,
|
|
local_files_only=local_files_only,
|
|
)
|
|
except HfHubHTTPError as e:
|
|
raise FileNotFoundError(
|
|
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
|
) from e
|
|
|
|
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
|
|
# something like --policy.path (in addition to --policy.type)
|
|
cli_overrides = policy_kwargs.pop("cli_overrides", [])
|
|
return draccus.parse(cls, config_file, args=cli_overrides)
|