Add a constant_warmup lr scheduler for dexvla
parent 8998ba3bb5
commit 105650522a
@@ -110,6 +110,29 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
         return LambdaLR(optimizer, lr_lambda, -1)
 
 
+@LRSchedulerConfig.register_subclass("constant_with_warmup")
+@dataclass
+class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
+    """Used by DexVLA to train Stage2"""
+
+    num_warmup_steps: int
+
+    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
+        def lr_lambda(current_step):
+            def linear_warmup_schedule(current_step):
+                if current_step <= 0:
+                    return 1 / (self.num_warmup_steps + 1)
+                frac = 1 - current_step / self.num_warmup_steps
+                return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
+
+            def constant_schedule(current_step):
+                return 1
+
+            if current_step < self.num_warmup_steps:
+                return linear_warmup_schedule(current_step)
+
+            return constant_schedule(current_step)
+
+        return LambdaLR(optimizer, lr_lambda, -1)
+
+
 def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
     state_dict = scheduler.state_dict()
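
The warmup multiplier is easy to misread, so a quick sanity check (not part of the commit): at step 0 the factor is 1 / (num_warmup_steps + 1), it climbs linearly toward 1 over the warmup window, and the constant branch holds it at 1 from then on. A minimal standalone sketch of the same lambda:

# Standalone sketch (not from the commit) of the multiplier that
# ConstantWithWarmupSchedulerConfig.build() hands to LambdaLR.
w = 100  # num_warmup_steps

def lr_lambda(step: int) -> float:
    if step < w:
        if step <= 0:
            return 1 / (w + 1)
        # Linear ramp: 1/(w+1) at step 0 up to ~1 at step w-1.
        return (1 / (w + 1) - 1) * (1 - step / w) + 1
    return 1.0  # constant after warmup

print(lr_lambda(0))    # 0.0099... -> warmup starts near zero
print(lr_lambda(50))   # 0.5049... -> roughly halfway up the ramp
print(lr_lambda(100))  # 1.0      -> full lr, held for the rest of training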
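
And a hypothetical end-to-end usage sketch. It assumes the config can be instantiated with just num_warmup_steps; if the base LRSchedulerConfig declares further required fields, those would need to be passed as well.

# Hypothetical usage (everything except the config class is illustrative).
# ConstantWithWarmupSchedulerConfig is assumed importable from the module
# this diff touches.
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

cfg = ConstantWithWarmupSchedulerConfig(num_warmup_steps=100)
scheduler = cfg.build(optimizer, num_training_steps=10_000)

for _ in range(10_000):
    optimizer.step()  # training step body elided
    scheduler.step()  # lr ramps for 100 steps, then holds at 1e-4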