nit: style fix
This commit is contained in:
parent
3b285e2f26
commit
9be0694271
|
@ -111,7 +111,7 @@ class DiffusionConfig:
|
||||||
diffusion_step_embed_dim: int = 128
|
diffusion_step_embed_dim: int = 128
|
||||||
use_film_scale_modulation: bool = True
|
use_film_scale_modulation: bool = True
|
||||||
# Noise scheduler.
|
# Noise scheduler.
|
||||||
noise_scheduler_type: str = 'DDPM'
|
noise_scheduler_type: str = "DDPM"
|
||||||
num_train_timesteps: int = 100
|
num_train_timesteps: int = 100
|
||||||
beta_schedule: str = "squaredcos_cap_v2"
|
beta_schedule: str = "squaredcos_cap_v2"
|
||||||
beta_start: float = 0.0001
|
beta_start: float = 0.0001
|
||||||
|
|
|
@ -28,19 +28,6 @@ from lerobot.common.policies.utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
|
||||||
"""
|
|
||||||
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
|
||||||
to the scheduler.
|
|
||||||
"""
|
|
||||||
if name == "DDPM":
|
|
||||||
return DDPMScheduler(**kwargs)
|
|
||||||
elif name == "DDIM":
|
|
||||||
return DDIMScheduler(**kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
"""
|
"""
|
||||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||||
|
@ -140,6 +127,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
return {"loss": loss}
|
return {"loss": loss}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||||
|
"""
|
||||||
|
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||||
|
to the scheduler.
|
||||||
|
"""
|
||||||
|
if name == "DDPM":
|
||||||
|
return DDPMScheduler(**kwargs)
|
||||||
|
elif name == "DDIM":
|
||||||
|
return DDIMScheduler(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||||
|
|
||||||
|
|
||||||
class DiffusionModel(nn.Module):
|
class DiffusionModel(nn.Module):
|
||||||
def __init__(self, config: DiffusionConfig):
|
def __init__(self, config: DiffusionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue