Add a constant_with_warmup LR scheduler for DexVLA

lesjie-wen 2025-03-18 15:16:08 +08:00
parent 8998ba3bb5
commit 105650522a
1 changed file with 23 additions and 0 deletions


@@ -110,6 +110,29 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
         return LambdaLR(optimizer, lr_lambda, -1)
 
 
+@LRSchedulerConfig.register_subclass("constant_with_warmup")
+@dataclass
+class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
+    """Used by DexVLA to train Stage2"""
+
+    num_warmup_steps: int
+
+    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
+        def lr_lambda(current_step):
+            def linear_warmup_schedule(current_step):
+                if current_step <= 0:
+                    return 1 / (self.num_warmup_steps + 1)
+                frac = 1 - current_step / self.num_warmup_steps
+                return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
+
+            def constant_schedule(current_step):
+                return 1
+
+            if current_step < self.num_warmup_steps:
+                return linear_warmup_schedule(current_step)
+
+            return constant_schedule(current_step)
+
+        return LambdaLR(optimizer, lr_lambda, -1)
+
+
 def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
     state_dict = scheduler.state_dict()
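
For context, the lambda returned by build is a multiplier on the optimizer's base learning rate: it ramps linearly from 1 / (num_warmup_steps + 1) at step 0 up to 1.0 at step num_warmup_steps, then stays constant at 1.0. Starting at 1 / (N + 1) rather than 0 avoids a zero learning rate on the very first step. Below is a minimal standalone sketch of the same math; the AdamW optimizer, the base LR of 1e-4, and num_warmup_steps = 1000 are illustrative values, not taken from the commit:

    import torch
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import LambdaLR

    NUM_WARMUP_STEPS = 1000  # illustrative value, not from the commit


    def lr_lambda(current_step: int) -> float:
        # Same math as the diff above: linear ramp from 1 / (N + 1) up to
        # 1.0 over the first N steps, then a constant multiplier of 1.0.
        if current_step <= 0:
            return 1 / (NUM_WARMUP_STEPS + 1)
        if current_step < NUM_WARMUP_STEPS:
            frac = 1 - current_step / NUM_WARMUP_STEPS
            return (1 / (NUM_WARMUP_STEPS + 1) - 1) * frac + 1
        return 1.0


    # A single dummy parameter is enough to exercise the scheduler.
    optimizer = AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

    for step in (0, 500, 999, 1000, 5000):
        # Multiplier is ~0.001 at step 0, ~0.5 halfway through warmup,
        # and exactly 1.0 from step 1000 onward.
        print(step, round(lr_lambda(step), 4))

Since the class registers itself under "constant_with_warmup" via LRSchedulerConfig.register_subclass, it should be selectable by that name anywhere the existing scheduler configs (such as the cosine-decay one above) are chosen.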