Support for DDIMScheduler in Diffusion
This commit is contained in:
parent
26d9a070d8
commit
7a80407d4b
|
@ -110,6 +110,7 @@ class DiffusionConfig:
|
|||
diffusion_step_embed_dim: int = 128
|
||||
use_film_scale_modulation: bool = True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: str = 'DDPM'
|
||||
num_train_timesteps: int = 100
|
||||
beta_schedule: str = "squaredcos_cap_v2"
|
||||
beta_start: float = 0.0001
|
||||
|
@ -144,3 +145,9 @@ class DiffusionConfig:
|
|||
raise ValueError(
|
||||
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||
)
|
||||
supported_noise_schedulers = ["DDPM", "DDIM"]
|
||||
if self.noise_scheduler_type not in supported_noise_schedulers:
|
||||
raise ValueError(
|
||||
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ import torch
|
|||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
from torch import Tensor, nn
|
||||
|
@ -138,12 +139,14 @@ class DiffusionModel(nn.Module):
|
|||
* config.n_obs_steps,
|
||||
)
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
noise_scheduler_class = DDPMScheduler
|
||||
if config.noise_scheduler_type == "DDIM":
|
||||
noise_scheduler_class = DDIMScheduler
|
||||
self.noise_scheduler = noise_scheduler_class(
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
beta_start=config.beta_start,
|
||||
beta_end=config.beta_end,
|
||||
beta_schedule=config.beta_schedule,
|
||||
variance_type="fixed_small",
|
||||
clip_sample=config.clip_sample,
|
||||
clip_sample_range=config.clip_sample_range,
|
||||
prediction_type=config.prediction_type,
|
||||
|
|
|
@ -85,6 +85,7 @@ policy:
|
|||
diffusion_step_embed_dim: 128
|
||||
use_film_scale_modulation: True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: DDPM
|
||||
num_train_timesteps: 100
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
|
|
Loading…
Reference in New Issue