factory for scheduler

This commit is contained in:
Akshay Kashyap 2024-05-08 10:31:21 -04:00
parent fe6899b91f
commit b54d05276c
1 changed files with 15 additions and 4 deletions

View File

@ -28,6 +28,19 @@ from lerobot.common.policies.utils import (
)
def _make_noise_scheduler(name: str, kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""
Factory for noise scheduler instances of the given name. All kwargs are passed
to the scheduler.
"""
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":
return DDIMScheduler(**kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {name}")
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"""
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
@ -139,10 +152,8 @@ class DiffusionModel(nn.Module):
* config.n_obs_steps,
)
noise_scheduler_class = DDPMScheduler
if config.noise_scheduler_type == "DDIM":
noise_scheduler_class = DDIMScheduler
self.noise_scheduler = noise_scheduler_class(
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start,
beta_end=config.beta_end,