From 9be0694271fc94d9ece4483666b29bbfb91eb7ed Mon Sep 17 00:00:00 2001 From: Akshay Kashyap Date: Wed, 8 May 2024 12:56:37 -0400 Subject: [PATCH] nit: style fix --- .../diffusion/configuration_diffusion.py | 2 +- .../policies/diffusion/modeling_diffusion.py | 26 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 6fe3587b..28a514ab 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -111,7 +111,7 @@ class DiffusionConfig: diffusion_step_embed_dim: int = 128 use_film_scale_modulation: bool = True # Noise scheduler. - noise_scheduler_type: str = 'DDPM' + noise_scheduler_type: str = "DDPM" num_train_timesteps: int = 100 beta_schedule: str = "squaredcos_cap_v2" beta_start: float = 0.0001 diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 3cf0516e..3115160f 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -28,19 +28,6 @@ from lerobot.common.policies.utils import ( ) -def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler: - """ - Factory for noise scheduler instances of the requested type. All kwargs are passed - to the scheduler. - """ - if name == "DDPM": - return DDPMScheduler(**kwargs) - elif name == "DDIM": - return DDIMScheduler(**kwargs) - else: - raise ValueError(f"Unsupported noise scheduler type {name}") - - class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): """ Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" @@ -140,6 +127,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): return {"loss": loss} +def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler: + """ + Factory for noise scheduler instances of the requested type. All kwargs are passed + to the scheduler. + """ + if name == "DDPM": + return DDPMScheduler(**kwargs) + elif name == "DDIM": + return DDIMScheduler(**kwargs) + else: + raise ValueError(f"Unsupported noise scheduler type {name}") + + class DiffusionModel(nn.Module): def __init__(self, config: DiffusionConfig): super().__init__()