factory for scheduler
This commit is contained in:
parent
fe6899b91f
commit
b54d05276c
|
@ -28,6 +28,19 @@ from lerobot.common.policies.utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_noise_scheduler(name: str, kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||||
|
"""
|
||||||
|
Factory for noise scheduler instances of the given name. All kwargs are passed
|
||||||
|
to the scheduler.
|
||||||
|
"""
|
||||||
|
if name == "DDPM":
|
||||||
|
return DDPMScheduler(**kwargs)
|
||||||
|
elif name == "DDIM":
|
||||||
|
return DDIMScheduler(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
"""
|
"""
|
||||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||||
|
@ -139,10 +152,8 @@ class DiffusionModel(nn.Module):
|
||||||
* config.n_obs_steps,
|
* config.n_obs_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
noise_scheduler_class = DDPMScheduler
|
self.noise_scheduler = _make_noise_scheduler(
|
||||||
if config.noise_scheduler_type == "DDIM":
|
config.noise_scheduler_type,
|
||||||
noise_scheduler_class = DDIMScheduler
|
|
||||||
self.noise_scheduler = noise_scheduler_class(
|
|
||||||
num_train_timesteps=config.num_train_timesteps,
|
num_train_timesteps=config.num_train_timesteps,
|
||||||
beta_start=config.beta_start,
|
beta_start=config.beta_start,
|
||||||
beta_end=config.beta_end,
|
beta_end=config.beta_end,
|
||||||
|
|
Loading…
Reference in New Issue