Support for DDIMScheduler in Diffusion

This commit is contained in:
Akshay Kashyap 2024-05-07 21:32:29 -04:00
parent 26d9a070d8
commit 7a80407d4b
3 changed files with 13 additions and 2 deletions

View File

@ -110,6 +110,7 @@ class DiffusionConfig:
diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True
# Noise scheduler.
noise_scheduler_type: str = 'DDPM'
num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001
@ -144,3 +145,9 @@ class DiffusionConfig:
raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
)
supported_noise_schedulers = ["DDPM", "DDIM"]
if self.noise_scheduler_type not in supported_noise_schedulers:
raise ValueError(
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
f"Got {self.noise_scheduler_type}."
)

View File

@ -14,6 +14,7 @@ import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
@ -138,12 +139,14 @@ class DiffusionModel(nn.Module):
* config.n_obs_steps,
)
self.noise_scheduler = DDPMScheduler(
noise_scheduler_class = DDPMScheduler
if config.noise_scheduler_type == "DDIM":
noise_scheduler_class = DDIMScheduler
self.noise_scheduler = noise_scheduler_class(
num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start,
beta_end=config.beta_end,
beta_schedule=config.beta_schedule,
variance_type="fixed_small",
clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type,

View File

@ -85,6 +85,7 @@ policy:
diffusion_step_embed_dim: 128
use_film_scale_modulation: True
# Noise scheduler.
noise_scheduler_type: DDPM
num_train_timesteps: 100
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001