Support for DDIMScheduler in Diffusion Policy (#146)
This commit is contained in:
parent
f5de57b385
commit
460df2ccea
|
@ -51,6 +51,7 @@ class DiffusionConfig:
|
|||
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
||||
Bias modulation is used be default, while this parameter indicates whether to also use scale
|
||||
modulation.
|
||||
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
|
||||
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
||||
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
||||
beta_start: Beta value for the first forward-diffusion step.
|
||||
|
@ -110,6 +111,7 @@ class DiffusionConfig:
|
|||
diffusion_step_embed_dim: int = 128
|
||||
use_film_scale_modulation: bool = True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: str = "DDPM"
|
||||
num_train_timesteps: int = 100
|
||||
beta_schedule: str = "squaredcos_cap_v2"
|
||||
beta_start: float = 0.0001
|
||||
|
@ -144,3 +146,9 @@ class DiffusionConfig:
|
|||
raise ValueError(
|
||||
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||
)
|
||||
supported_noise_schedulers = ["DDPM", "DDIM"]
|
||||
if self.noise_scheduler_type not in supported_noise_schedulers:
|
||||
raise ValueError(
|
||||
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
|
|
@ -13,6 +13,7 @@ import einops
|
|||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
|
@ -126,6 +127,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
return {"loss": loss}
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
"""
|
||||
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||
to the scheduler.
|
||||
"""
|
||||
if name == "DDPM":
|
||||
return DDPMScheduler(**kwargs)
|
||||
elif name == "DDIM":
|
||||
return DDIMScheduler(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||
|
||||
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
|
@ -138,12 +152,12 @@ class DiffusionModel(nn.Module):
|
|||
* config.n_obs_steps,
|
||||
)
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
beta_start=config.beta_start,
|
||||
beta_end=config.beta_end,
|
||||
beta_schedule=config.beta_schedule,
|
||||
variance_type="fixed_small",
|
||||
clip_sample=config.clip_sample,
|
||||
clip_sample_range=config.clip_sample_range,
|
||||
prediction_type=config.prediction_type,
|
||||
|
|
|
@ -85,6 +85,7 @@ policy:
|
|||
diffusion_step_embed_dim: 128
|
||||
use_film_scale_modulation: True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: DDPM
|
||||
num_train_timesteps: 100
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
|
|
Loading…
Reference in New Issue