[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
31788f65dd
commit
a0510c0f5e
|
@ -109,13 +109,16 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
|||
return cosine_decay_schedule(current_step)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("constant_with_warmup")
|
||||
@dataclass
|
||||
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by DexVLA to train Stage2"""
|
||||
num_warmup_steps: int
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
|
||||
num_warmup_steps: int
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
|
@ -133,6 +136,7 @@ class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
|
|||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
|
||||
state_dict = scheduler.state_dict()
|
||||
write_json(state_dict, save_dir / SCHEDULER_STATE)
|
||||
|
|
Loading…
Reference in New Issue