Support for DDIMScheduler in Diffusion Policy (#146)

This commit is contained in:
Akshay Kashyap 2024-05-08 13:05:16 -04:00 committed by GitHub
parent f5de57b385
commit 460df2ccea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 2 deletions

View File

@ -51,6 +51,7 @@ class DiffusionConfig:
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
Bias modulation is used be default, while this parameter indicates whether to also use scale Bias modulation is used be default, while this parameter indicates whether to also use scale
modulation. modulation.
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule. num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers. beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
beta_start: Beta value for the first forward-diffusion step. beta_start: Beta value for the first forward-diffusion step.
@ -110,6 +111,7 @@ class DiffusionConfig:
diffusion_step_embed_dim: int = 128 diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True use_film_scale_modulation: bool = True
# Noise scheduler. # Noise scheduler.
noise_scheduler_type: str = "DDPM"
num_train_timesteps: int = 100 num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2" beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001 beta_start: float = 0.0001
@ -144,3 +146,9 @@ class DiffusionConfig:
raise ValueError( raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}." f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
) )
supported_noise_schedulers = ["DDPM", "DDIM"]
if self.noise_scheduler_type not in supported_noise_schedulers:
raise ValueError(
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
f"Got {self.noise_scheduler_type}."
)

View File

@ -13,6 +13,7 @@ import einops
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
import torchvision import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax from robomimic.models.base_nets import SpatialSoftmax
@ -126,6 +127,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
return {"loss": loss} return {"loss": loss}
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":
return DDIMScheduler(**kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {name}")
class DiffusionModel(nn.Module): class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig): def __init__(self, config: DiffusionConfig):
super().__init__() super().__init__()
@ -138,12 +152,12 @@ class DiffusionModel(nn.Module):
* config.n_obs_steps, * config.n_obs_steps,
) )
self.noise_scheduler = DDPMScheduler( self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps, num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start, beta_start=config.beta_start,
beta_end=config.beta_end, beta_end=config.beta_end,
beta_schedule=config.beta_schedule, beta_schedule=config.beta_schedule,
variance_type="fixed_small",
clip_sample=config.clip_sample, clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range, clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type, prediction_type=config.prediction_type,

View File

@ -85,6 +85,7 @@ policy:
diffusion_step_embed_dim: 128 diffusion_step_embed_dim: 128
use_film_scale_modulation: True use_film_scale_modulation: True
# Noise scheduler. # Noise scheduler.
noise_scheduler_type: DDPM
num_train_timesteps: 100 num_train_timesteps: 100
beta_schedule: squaredcos_cap_v2 beta_schedule: squaredcos_cap_v2
beta_start: 0.0001 beta_start: 0.0001