remove redundant comment, change scheduler name
This commit is contained in:
parent
0b663243e3
commit
8ee1e53fee
|
@ -490,10 +490,9 @@ class VQBeTScheduler:
|
||||||
def __init__(self, optimizer, cfg):
|
def __init__(self, optimizer, cfg):
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
self.discretize_step = cfg.training.discretize_step
|
self.discretize_step = cfg.training.discretize_step
|
||||||
# self.offline_steps = cfg.training.offline_steps
|
|
||||||
self.optimizing_step = 0
|
self.optimizing_step = 0
|
||||||
|
|
||||||
self.lr_scheduler1 = get_scheduler(
|
self.lr_scheduler = get_scheduler(
|
||||||
cfg.training.lr_scheduler,
|
cfg.training.lr_scheduler,
|
||||||
optimizer=optimizer.encoder_optimizer,
|
optimizer=optimizer.encoder_optimizer,
|
||||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||||
|
@ -504,9 +503,7 @@ class VQBeTScheduler:
|
||||||
def step(self):
|
def step(self):
|
||||||
self.optimizing_step +=1
|
self.optimizing_step +=1
|
||||||
if self.optimizing_step >= self.discretize_step:
|
if self.optimizing_step >= self.discretize_step:
|
||||||
self.lr_scheduler1.step()
|
self.lr_scheduler.step()
|
||||||
# self.lr_scheduler2.step()
|
|
||||||
# self.lr_scheduler3.step()
|
|
||||||
|
|
||||||
class VQBeTRgbEncoder(nn.Module):
|
class VQBeTRgbEncoder(nn.Module):
|
||||||
"""Encoder an RGB image into a 1D feature vector.
|
"""Encoder an RGB image into a 1D feature vector.
|
||||||
|
|
Loading…
Reference in New Issue