From 47de07658c3ee02c9b65b2632152c9311759a46b Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 12:56:21 +0100 Subject: [PATCH 1/3] Override pretrained model config (#147) --- lerobot/common/policies/factory.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 4819ca80..a819d18f 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,4 +1,5 @@ import inspect +import logging from omegaconf import DictConfig, OmegaConf @@ -8,9 +9,10 @@ from lerobot.common.utils.utils import get_safe_torch_device def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): expected_kwargs = set(inspect.signature(policy_cfg_class).parameters) - assert set(hydra_cfg.policy).issuperset( - expected_kwargs - ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" + if not set(hydra_cfg.policy).issuperset(expected_kwargs): + logging.warning( + f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" + ) policy_cfg = policy_cfg_class( **{ k: v @@ -62,11 +64,18 @@ def make_policy( policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name) + policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) if pretrained_policy_name_or_path is None: - policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) + # Make a fresh policy. policy = policy_cls(policy_cfg, dataset_stats) else: - policy = policy_cls.from_pretrained(pretrained_policy_name_or_path) + # Load a pretrained policy and override the config if needed (for example, if there are inference-time + # hyperparameters that we want to vary). + # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained + # weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should + # make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274. + policy = policy_cls(policy_cfg) + policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()) policy.to(get_safe_torch_device(hydra_cfg.device)) From f5de57b385090da0cefed43ca4f2d832bb554af1 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 14:57:29 +0100 Subject: [PATCH 2/3] Fix SpatialSoftmax input shape (#150) --- .../common/policies/diffusion/modeling_diffusion.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 91cf6dd0..a7ba5442 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -315,11 +315,13 @@ class DiffusionRgbEncoder(nn.Module): # Set up pooling and final layers. # Use a dry run to get the feature map shape. + # The dummy input should take the number of image channels from `config.input_shapes` and it should use the + # height and width from `config.crop_shape`. + dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape)) with torch.inference_mode(): - feat_map_shape = tuple( - self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:] - ) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints) + dummy_feature_map = self.backbone(dummy_input) + feature_map_shape = tuple(dummy_feature_map.shape[1:]) + self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU() From 460df2ccea1e8dbf3bb76f53ee6fde30761a00fc Mon Sep 17 00:00:00 2001 From: Akshay Kashyap Date: Wed, 8 May 2024 13:05:16 -0400 Subject: [PATCH 3/3] Support for DDIMScheduler in Diffusion Policy (#146) --- .../diffusion/configuration_diffusion.py | 8 ++++++++ .../policies/diffusion/modeling_diffusion.py | 18 ++++++++++++++++-- lerobot/configs/policy/diffusion.yaml | 1 + 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index d7341c33..28a514ab 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -51,6 +51,7 @@ class DiffusionConfig: use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. Bias modulation is used be default, while this parameter indicates whether to also use scale modulation. + noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"]. num_train_timesteps: Number of diffusion steps for the forward diffusion schedule. beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers. beta_start: Beta value for the first forward-diffusion step. @@ -110,6 +111,7 @@ class DiffusionConfig: diffusion_step_embed_dim: int = 128 use_film_scale_modulation: bool = True # Noise scheduler. + noise_scheduler_type: str = "DDPM" num_train_timesteps: int = 100 beta_schedule: str = "squaredcos_cap_v2" beta_start: float = 0.0001 @@ -144,3 +146,9 @@ class DiffusionConfig: raise ValueError( f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}." ) + supported_noise_schedulers = ["DDPM", "DDIM"] + if self.noise_scheduler_type not in supported_noise_schedulers: + raise ValueError( + f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. " + f"Got {self.noise_scheduler_type}." + ) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index a7ba5442..3115160f 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -13,6 +13,7 @@ import einops import torch import torch.nn.functional as F # noqa: N812 import torchvision +from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from huggingface_hub import PyTorchModelHubMixin from robomimic.models.base_nets import SpatialSoftmax @@ -126,6 +127,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): return {"loss": loss} +def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler: + """ + Factory for noise scheduler instances of the requested type. All kwargs are passed + to the scheduler. + """ + if name == "DDPM": + return DDPMScheduler(**kwargs) + elif name == "DDIM": + return DDIMScheduler(**kwargs) + else: + raise ValueError(f"Unsupported noise scheduler type {name}") + + class DiffusionModel(nn.Module): def __init__(self, config: DiffusionConfig): super().__init__() @@ -138,12 +152,12 @@ class DiffusionModel(nn.Module): * config.n_obs_steps, ) - self.noise_scheduler = DDPMScheduler( + self.noise_scheduler = _make_noise_scheduler( + config.noise_scheduler_type, num_train_timesteps=config.num_train_timesteps, beta_start=config.beta_start, beta_end=config.beta_end, beta_schedule=config.beta_schedule, - variance_type="fixed_small", clip_sample=config.clip_sample, clip_sample_range=config.clip_sample_range, prediction_type=config.prediction_type, diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 2d611c88..9a4aeb2a 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -85,6 +85,7 @@ policy: diffusion_step_embed_dim: 128 use_film_scale_modulation: True # Noise scheduler. + noise_scheduler_type: DDPM num_train_timesteps: 100 beta_schedule: squaredcos_cap_v2 beta_start: 0.0001