Support for DDIMScheduler in Diffusion Policy (#146)
This commit is contained in:
parent
f5de57b385
commit
460df2ccea
|
@ -51,6 +51,7 @@ class DiffusionConfig:
|
||||||
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
||||||
Bias modulation is used by default, while this parameter indicates whether to also use scale
|
Bias modulation is used by default, while this parameter indicates whether to also use scale
|
||||||
modulation.
|
modulation.
|
||||||
|
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
|
||||||
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
||||||
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
||||||
beta_start: Beta value for the first forward-diffusion step.
|
beta_start: Beta value for the first forward-diffusion step.
|
||||||
|
@ -110,6 +111,7 @@ class DiffusionConfig:
|
||||||
diffusion_step_embed_dim: int = 128
|
diffusion_step_embed_dim: int = 128
|
||||||
use_film_scale_modulation: bool = True
|
use_film_scale_modulation: bool = True
|
||||||
# Noise scheduler.
|
# Noise scheduler.
|
||||||
|
noise_scheduler_type: str = "DDPM"
|
||||||
num_train_timesteps: int = 100
|
num_train_timesteps: int = 100
|
||||||
beta_schedule: str = "squaredcos_cap_v2"
|
beta_schedule: str = "squaredcos_cap_v2"
|
||||||
beta_start: float = 0.0001
|
beta_start: float = 0.0001
|
||||||
|
@ -144,3 +146,9 @@ class DiffusionConfig:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||||
)
|
)
|
||||||
|
supported_noise_schedulers = ["DDPM", "DDIM"]
|
||||||
|
if self.noise_scheduler_type not in supported_noise_schedulers:
|
||||||
|
raise ValueError(
|
||||||
|
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
||||||
|
f"Got {self.noise_scheduler_type}."
|
||||||
|
)
|
||||||
|
|
|
@ -13,6 +13,7 @@ import einops
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision
|
import torchvision
|
||||||
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
from robomimic.models.base_nets import SpatialSoftmax
|
from robomimic.models.base_nets import SpatialSoftmax
|
||||||
|
@ -126,6 +127,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
return {"loss": loss}
|
return {"loss": loss}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_noise_scheduler(name: str, **kwargs) -> DDPMScheduler | DDIMScheduler:
    """
    Factory for noise scheduler instances of the requested type.

    Args:
        name: Scheduler type to build. Supported values are "DDPM" and "DDIM".
        **kwargs: Forwarded unchanged to the selected scheduler's constructor. Deliberately left
            unannotated (implicitly `Any`): an annotation on `**kwargs` describes each individual
            value, so annotating it as `dict` wrongly declared every keyword value to be a dict,
            while callers pass e.g. int, float and str settings.

    Returns:
        A `DDPMScheduler` when `name` is "DDPM", or a `DDIMScheduler` when `name` is "DDIM".

    Raises:
        ValueError: If `name` is not one of the supported scheduler types.
    """
    # Guard-clause style: each supported name returns immediately, anything else is an error.
    if name == "DDPM":
        return DDPMScheduler(**kwargs)
    if name == "DDIM":
        return DDIMScheduler(**kwargs)
    raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||||
|
|
||||||
|
|
||||||
class DiffusionModel(nn.Module):
|
class DiffusionModel(nn.Module):
|
||||||
def __init__(self, config: DiffusionConfig):
|
def __init__(self, config: DiffusionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -138,12 +152,12 @@ class DiffusionModel(nn.Module):
|
||||||
* config.n_obs_steps,
|
* config.n_obs_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.noise_scheduler = DDPMScheduler(
|
self.noise_scheduler = _make_noise_scheduler(
|
||||||
|
config.noise_scheduler_type,
|
||||||
num_train_timesteps=config.num_train_timesteps,
|
num_train_timesteps=config.num_train_timesteps,
|
||||||
beta_start=config.beta_start,
|
beta_start=config.beta_start,
|
||||||
beta_end=config.beta_end,
|
beta_end=config.beta_end,
|
||||||
beta_schedule=config.beta_schedule,
|
beta_schedule=config.beta_schedule,
|
||||||
variance_type="fixed_small",
|
|
||||||
clip_sample=config.clip_sample,
|
clip_sample=config.clip_sample,
|
||||||
clip_sample_range=config.clip_sample_range,
|
clip_sample_range=config.clip_sample_range,
|
||||||
prediction_type=config.prediction_type,
|
prediction_type=config.prediction_type,
|
||||||
|
|
|
@ -85,6 +85,7 @@ policy:
|
||||||
diffusion_step_embed_dim: 128
|
diffusion_step_embed_dim: 128
|
||||||
use_film_scale_modulation: True
|
use_film_scale_modulation: True
|
||||||
# Noise scheduler.
|
# Noise scheduler.
|
||||||
|
noise_scheduler_type: DDPM
|
||||||
num_train_timesteps: 100
|
num_train_timesteps: 100
|
||||||
beta_schedule: squaredcos_cap_v2
|
beta_schedule: squaredcos_cap_v2
|
||||||
beta_start: 0.0001
|
beta_start: 0.0001
|
||||||
|
|
Loading…
Reference in New Issue