import abc
import math
from dataclasses import asdict, dataclass

import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
    """Base config for LR schedulers; concrete subclasses register under a choice name."""

    num_warmup_steps: int

    @property
    def type(self) -> str:
        # The name this subclass was registered under (e.g. "diffuser", "vqbet").
        return self.get_choice_name(self.__class__)

    @abc.abstractmethod
    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
        raise NotImplementedError

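# Extension sketch (hypothetical, not part of this module): a new scheduler plugs in by
# registering a subclass and implementing `build`. The "constant" choice name and the
# ConstantWithWarmupConfig class are illustrative assumptions only.
#
# @LRSchedulerConfig.register_subclass("constant")
# @dataclass
# class ConstantWithWarmupConfig(LRSchedulerConfig):
#     num_warmup_steps: int = 0
#
#     def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
#         def lr_lambda(current_step: int) -> float:
#             # Linear warmup, then hold the base lr for the rest of training.
#             if current_step < self.num_warmup_steps:
#                 return current_step / max(1, self.num_warmup_steps)
#             return 1.0
#
#         return LambdaLR(optimizer, lr_lambda, -1)
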
@LRSchedulerConfig.register_subclass("diffuser")
@dataclass
class DiffuserSchedulerConfig(LRSchedulerConfig):
    """Thin wrapper around `diffusers.optimization.get_scheduler`."""

    name: str = "cosine"
    num_warmup_steps: int | None = None

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        # Imported lazily so `diffusers` is only required when this scheduler is used.
        from diffusers.optimization import get_scheduler

        # `asdict(self)` contributes the `name` and `num_warmup_steps` fields.
        kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
        return get_scheduler(**kwargs)

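# Usage sketch for DiffuserSchedulerConfig, assuming `diffusers` is installed. The
# optimizer setup is illustrative, not part of this module.
#
# import torch
# optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
# cfg = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=500)
# scheduler = cfg.build(optimizer, num_training_steps=100_000)
# for _ in range(10):
#     optimizer.step()
#     scheduler.step()  # one scheduler step per optimizer step
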
@LRSchedulerConfig.register_subclass("vqbet")
@dataclass
class VQBeTSchedulerConfig(LRSchedulerConfig):
    num_warmup_steps: int
    num_vqvae_training_steps: int
    num_cycles: float = 0.5

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        def lr_lambda(current_step):
            # Phase 1: hold the lr constant while the VQ-VAE is being trained.
            if current_step < self.num_vqvae_training_steps:
                return 1.0
            # Phase 2: linear warmup, starting once VQ-VAE training ends.
            adjusted_step = current_step - self.num_vqvae_training_steps
            if adjusted_step < self.num_warmup_steps:
                return float(adjusted_step) / float(max(1, self.num_warmup_steps))
            # Phase 3: cosine decay over the remaining steps.
            progress = float(adjusted_step - self.num_warmup_steps) / float(
                max(1, num_training_steps - self.num_warmup_steps)
            )
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))

        return LambdaLR(optimizer, lr_lambda, -1)
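# Usage sketch for VQBeTSchedulerConfig; the optimizer and step counts are illustrative
# assumptions. The multiplier is 1.0 for the first num_vqvae_training_steps steps, warms
# up linearly for num_warmup_steps steps, then follows a cosine decay.
#
# import torch
# optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
# cfg = VQBeTSchedulerConfig(num_warmup_steps=500, num_vqvae_training_steps=20_000)
# scheduler = cfg.build(optimizer, num_training_steps=100_000)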