From 47de07658c3ee02c9b65b2632152c9311759a46b Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 12:56:21 +0100 Subject: [PATCH 1/2] Override pretrained model config (#147) --- lerobot/common/policies/factory.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 4819ca80..a819d18f 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,4 +1,5 @@ import inspect +import logging from omegaconf import DictConfig, OmegaConf @@ -8,9 +9,10 @@ from lerobot.common.utils.utils import get_safe_torch_device def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): expected_kwargs = set(inspect.signature(policy_cfg_class).parameters) - assert set(hydra_cfg.policy).issuperset( - expected_kwargs - ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" + if not set(hydra_cfg.policy).issuperset(expected_kwargs): + logging.warning( + f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" + ) policy_cfg = policy_cfg_class( **{ k: v @@ -62,11 +64,18 @@ def make_policy( policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name) + policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) if pretrained_policy_name_or_path is None: - policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) + # Make a fresh policy. policy = policy_cls(policy_cfg, dataset_stats) else: - policy = policy_cls.from_pretrained(pretrained_policy_name_or_path) + # Load a pretrained policy and override the config if needed (for example, if there are inference-time + # hyperparameters that we want to vary). + # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained + # weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should + # make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274. + policy = policy_cls(policy_cfg) + policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()) policy.to(get_safe_torch_device(hydra_cfg.device)) From f5de57b385090da0cefed43ca4f2d832bb554af1 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 14:57:29 +0100 Subject: [PATCH 2/2] Fix SpatialSoftmax input shape (#150) --- .../common/policies/diffusion/modeling_diffusion.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 91cf6dd0..a7ba5442 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -315,11 +315,13 @@ class DiffusionRgbEncoder(nn.Module): # Set up pooling and final layers. # Use a dry run to get the feature map shape. + # The dummy input should take the number of image channels from `config.input_shapes` and it should use the + # height and width from `config.crop_shape`. + dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape)) with torch.inference_mode(): - feat_map_shape = tuple( - self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:] - ) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints) + dummy_feature_map = self.backbone(dummy_input) + feature_map_shape = tuple(dummy_feature_map.shape[1:]) + self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU()