[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-04-03 05:57:51 +00:00
parent 31788f65dd
commit a0510c0f5e
2 changed files with 9 additions and 5 deletions

View File

@ -109,13 +109,16 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
return cosine_decay_schedule(current_step) return cosine_decay_schedule(current_step)
return LambdaLR(optimizer, lr_lambda, -1) return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("constant_with_warmup") @LRSchedulerConfig.register_subclass("constant_with_warmup")
@dataclass @dataclass
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig): class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Used by DexVLA to train Stage2""" """Used by DexVLA to train Stage2"""
num_warmup_steps: int
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
num_warmup_steps: int
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step): def lr_lambda(current_step):
def linear_warmup_schedule(current_step): def linear_warmup_schedule(current_step):
if current_step <= 0: if current_step <= 0:
@ -133,6 +136,7 @@ class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1) return LambdaLR(optimizer, lr_lambda, -1)
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None: def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
state_dict = scheduler.state_dict() state_dict = scheduler.state_dict()
write_json(state_dict, save_dir / SCHEDULER_STATE) write_json(state_dict, save_dir / SCHEDULER_STATE)