From 47de07658c3ee02c9b65b2632152c9311759a46b Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 12:56:21 +0100 Subject: [PATCH] Override pretrained model config (#147) --- lerobot/common/policies/factory.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 4819ca80..a819d18f 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,4 +1,5 @@ import inspect +import logging from omegaconf import DictConfig, OmegaConf @@ -8,9 +9,10 @@ from lerobot.common.utils.utils import get_safe_torch_device def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): expected_kwargs = set(inspect.signature(policy_cfg_class).parameters) - assert set(hydra_cfg.policy).issuperset( - expected_kwargs - ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" + if not set(hydra_cfg.policy).issuperset(expected_kwargs): + logging.warning( + f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" + ) policy_cfg = policy_cfg_class( **{ k: v @@ -62,11 +64,18 @@ def make_policy( policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name) + policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) if pretrained_policy_name_or_path is None: - policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) + # Make a fresh policy. policy = policy_cls(policy_cfg, dataset_stats) else: - policy = policy_cls.from_pretrained(pretrained_policy_name_or_path) + # Load a pretrained policy and override the config if needed (for example, if there are inference-time + # hyperparameters that we want to vary). + # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained + # weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should + # make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274. + policy = policy_cls(policy_cfg) + policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()) policy.to(get_safe_torch_device(hydra_cfg.device))