From 460df2ccea1e8dbf3bb76f53ee6fde30761a00fc Mon Sep 17 00:00:00 2001 From: Akshay Kashyap Date: Wed, 8 May 2024 13:05:16 -0400 Subject: [PATCH] Support for DDIMScheduler in Diffusion Policy (#146) --- .../diffusion/configuration_diffusion.py | 8 ++++++++ .../policies/diffusion/modeling_diffusion.py | 18 ++++++++++++++++-- lerobot/configs/policy/diffusion.yaml | 1 + 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index d7341c33..28a514ab 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -51,6 +51,7 @@ class DiffusionConfig: use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. Bias modulation is used be default, while this parameter indicates whether to also use scale modulation. + noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"]. num_train_timesteps: Number of diffusion steps for the forward diffusion schedule. beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers. beta_start: Beta value for the first forward-diffusion step. @@ -110,6 +111,7 @@ class DiffusionConfig: diffusion_step_embed_dim: int = 128 use_film_scale_modulation: bool = True # Noise scheduler. + noise_scheduler_type: str = "DDPM" num_train_timesteps: int = 100 beta_schedule: str = "squaredcos_cap_v2" beta_start: float = 0.0001 @@ -144,3 +146,9 @@ class DiffusionConfig: raise ValueError( f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}." ) + supported_noise_schedulers = ["DDPM", "DDIM"] + if self.noise_scheduler_type not in supported_noise_schedulers: + raise ValueError( + f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. " + f"Got {self.noise_scheduler_type}." + ) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index a7ba5442..3115160f 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -13,6 +13,7 @@ import einops import torch import torch.nn.functional as F # noqa: N812 import torchvision +from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from huggingface_hub import PyTorchModelHubMixin from robomimic.models.base_nets import SpatialSoftmax @@ -126,6 +127,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): return {"loss": loss} +def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler: + """ + Factory for noise scheduler instances of the requested type. All kwargs are passed + to the scheduler. + """ + if name == "DDPM": + return DDPMScheduler(**kwargs) + elif name == "DDIM": + return DDIMScheduler(**kwargs) + else: + raise ValueError(f"Unsupported noise scheduler type {name}") + + class DiffusionModel(nn.Module): def __init__(self, config: DiffusionConfig): super().__init__() @@ -138,12 +152,12 @@ class DiffusionModel(nn.Module): * config.n_obs_steps, ) - self.noise_scheduler = DDPMScheduler( + self.noise_scheduler = _make_noise_scheduler( + config.noise_scheduler_type, num_train_timesteps=config.num_train_timesteps, beta_start=config.beta_start, beta_end=config.beta_end, beta_schedule=config.beta_schedule, - variance_type="fixed_small", clip_sample=config.clip_sample, clip_sample_range=config.clip_sample_range, prediction_type=config.prediction_type, diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 2d611c88..9a4aeb2a 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -85,6 +85,7 @@ policy: diffusion_step_embed_dim: 128 use_film_scale_modulation: True # Noise scheduler. + noise_scheduler_type: DDPM num_train_timesteps: 100 beta_schedule: squaredcos_cap_v2 beta_start: 0.0001