diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 0a13ece6..4ca21413 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -28,6 +28,19 @@ from lerobot.common.policies.utils import ( ) +def _make_noise_scheduler(name: str, kwargs: dict) -> DDPMScheduler | DDIMScheduler: + """ + Factory for noise scheduler instances of the given name. All kwargs are passed + to the scheduler. + """ + if name == "DDPM": + return DDPMScheduler(**kwargs) + elif name == "DDIM": + return DDIMScheduler(**kwargs) + else: + raise ValueError(f"Unsupported noise scheduler type {name}") + + class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): """ Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" @@ -139,10 +152,8 @@ class DiffusionModel(nn.Module): * config.n_obs_steps, ) - noise_scheduler_class = DDPMScheduler - if config.noise_scheduler_type == "DDIM": - noise_scheduler_class = DDIMScheduler - self.noise_scheduler = noise_scheduler_class( + self.noise_scheduler = _make_noise_scheduler( + config.noise_scheduler_type, num_train_timesteps=config.num_train_timesteps, beta_start=config.beta_start, beta_end=config.beta_end,