[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent 31788f65dd
commit a0510c0f5e
@@ -109,13 +109,16 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
             return cosine_decay_schedule(current_step)
 
         return LambdaLR(optimizer, lr_lambda, -1)
 
+
 @LRSchedulerConfig.register_subclass("constant_with_warmup")
 @dataclass
 class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
     """Used by DexVLA to train Stage2"""
+
     num_warmup_steps: int
+
     def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
         def lr_lambda(current_step):
             def linear_warmup_schedule(current_step):
                 if current_step <= 0:
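The hunk above cuts off inside linear_warmup_schedule. As a rough guide to what a constant-with-warmup multiplier does, here is a minimal self-contained sketch, assuming the usual convention that lr_lambda returns a factor applied to the optimizer's base learning rate; the ramp formula below is a generic linear warmup, not necessarily the exact one in this file, and num_warmup_steps = 100 is a hypothetical value:

# Minimal sketch: linear warmup to the base lr, then constant.
# "num_warmup_steps" mirrors the dataclass field in the diff; the ramp
# formula is a generic choice, not confirmed by this commit.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

num_warmup_steps = 100  # hypothetical value


def lr_lambda(current_step: int) -> float:
    if current_step < num_warmup_steps:
        # Factor grows linearly from 0 toward 1 across the warmup window.
        return float(current_step) / float(max(1, num_warmup_steps))
    return 1.0  # constant multiplier after warmup


optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda, -1)  # -1: begin at step 0

Calling scheduler.step() once per optimizer step advances current_step; with last_epoch=-1 (the -1 passed in the diff's return LambdaLR(optimizer, lr_lambda, -1)) the schedule starts from step 0.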
@@ -133,6 +136,7 @@ class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
 
         return LambdaLR(optimizer, lr_lambda, -1)
 
+
 def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
     state_dict = scheduler.state_dict()
     write_json(state_dict, save_dir / SCHEDULER_STATE)
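save_scheduler_state persists whatever scheduler.state_dict() returns through the project's write_json helper. A hedged round-trip sketch follows; SCHEDULER_STATE and write_json are stand-ins mirroring the names in the diff (their real definitions are not shown here), and load_scheduler_state is a hypothetical counterpart built on PyTorch's load_state_dict:

# Sketch of save/restore; SCHEDULER_STATE and write_json are stand-ins for
# project helpers that this diff references but does not define.
import json
from pathlib import Path

from torch.optim.lr_scheduler import LRScheduler

SCHEDULER_STATE = "scheduler_state.json"  # hypothetical filename


def write_json(obj: dict, path: Path) -> None:
    path.write_text(json.dumps(obj))


def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
    state_dict = scheduler.state_dict()
    write_json(state_dict, save_dir / SCHEDULER_STATE)


def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
    # LambdaLR state is plain data (step counters, base lrs), so a JSON
    # round-trip is enough to resume the schedule where it left off.
    state = json.loads((save_dir / SCHEDULER_STATE).read_text())
    scheduler.load_state_dict(state)
    return scheduler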