diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py
index 7e158394..486a0b99 100644
--- a/lerobot/common/optim/schedulers.py
+++ b/lerobot/common/optim/schedulers.py
@@ -110,6 +110,29 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
         return LambdaLR(optimizer, lr_lambda, -1)
 
 
+@LRSchedulerConfig.register_subclass("constant_with_warmup")
+@dataclass
+class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
+    """Used by DexVLA to train Stage2"""
+
+    num_warmup_steps: int
+
+    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
+        def lr_lambda(current_step):
+            def linear_warmup_schedule(current_step):
+                if current_step <= 0:
+                    return 1 / (self.num_warmup_steps + 1)
+                frac = 1 - current_step / self.num_warmup_steps
+                return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
+
+            def constant_schedule(current_step):
+                return 1
+
+            if current_step < self.num_warmup_steps:
+                return linear_warmup_schedule(current_step)
+
+            return constant_schedule(current_step)
+
+        return LambdaLR(optimizer, lr_lambda, -1)
+
+
 def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
     state_dict = scheduler.state_dict()
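
Note on the schedule (not part of the patch): during warmup, lr_lambda simplifies to the
closed form (current_step + 1) / (num_warmup_steps + 1), so the multiplier ramps linearly
from 1 / (num_warmup_steps + 1) at step 0 to 1 at step num_warmup_steps, then holds at 1.
Starting at 1 / (num_warmup_steps + 1) rather than 0 means the very first optimizer step
already uses a nonzero learning rate. Below is a minimal standalone sketch reproducing the
schedule with a plain LambdaLR; num_warmup_steps = 4 and base_lr = 1e-3 are arbitrary
values chosen for illustration, not taken from the patch:

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import LambdaLR

    num_warmup_steps = 4  # illustrative value, not from the patch
    base_lr = 1e-3        # illustrative value, not from the patch

    def lr_lambda(current_step):
        # Same math as the patch: linear ramp during warmup, constant 1 afterwards.
        if current_step < num_warmup_steps:
            if current_step <= 0:
                return 1 / (num_warmup_steps + 1)
            frac = 1 - current_step / num_warmup_steps
            return (1 / (num_warmup_steps + 1) - 1) * frac + 1
        return 1

    optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
    scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

    for step in range(7):
        # Multipliers: 1/5, 2/5, 3/5, 4/5, 1, 1, 1
        # -> lr: 2e-4, 4e-4, 6e-4, 8e-4, 1e-3, 1e-3, 1e-3
        print(step, optimizer.param_groups[0]["lr"])
        optimizer.step()
        scheduler.step()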