Port LR Schedulers
This commit is contained in:
parent aca424a481
commit 90e6df3ecb

@@ -0,0 +1,166 @@
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch learning rate schedulers.
|
||||||
|
|
||||||
|
Note: Most of this code was copied as is from the diffusers and transformers libraries with removal of
|
||||||
|
certain features for simplication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
from enum import Enum
from functools import partial
from typing import Optional, Union

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


class SchedulerType(Enum):
    COSINE = "cosine"
    INVERSE_SQRT = "inverse_sqrt"


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function from the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in the schedule (the default of 0.5 just decreases
            from the max value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

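For illustration, a minimal sketch of how this warmup-plus-cosine schedule behaves; the toy linear module, optimizer, and step counts below are assumptions, not taken from the diff.

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)

for step in range(100):
    optimizer.step()  # optimizer first, then scheduler, as in the tests below
    scheduler.step()
    if step in (0, 9, 49, 99):
        # lr climbs linearly to 1e-3 over the first 10 steps, then follows a half cosine down to 0
        print(step, scheduler.get_last_lr()[0])
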
def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: Optional[int] = None):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    shift = timescale - num_warmup_steps
    decay = 1.0 / math.sqrt((current_step + shift) / timescale)
    return decay


def get_inverse_sqrt_schedule(
    optimizer: Optimizer, num_warmup_steps: int, timescale: Optional[int] = None, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
    warmup period during which the lr increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
            Time scale.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Note: this implementation is adapted from
    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930

    if timescale is None:
        timescale = num_warmup_steps or 10_000

    lr_lambda = partial(
        _get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

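A minimal sketch of the resulting decay (the toy optimizer and step count are assumptions): with the default timescale equal to num_warmup_steps, the lr warms up linearly and then decays as 1/sqrt(step / timescale).

import torch

optimizer = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=1e-3)
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=100)

for _ in range(1000):
    optimizer.step()
    scheduler.step()

# After 1000 scheduler steps: 1e-3 / sqrt(1000 / 100) ≈ 3.16e-4
print(scheduler.get_last_lr()[0])
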
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional); the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional); the function will raise an error if it's unset and the scheduler type requires it.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    """
    name = SchedulerType(name)
    if name not in TYPE_TO_SCHEDULER_FUNCTION:
        raise ValueError(
            f"Unsupported scheduler {name}, expected one of {list(TYPE_TO_SCHEDULER_FUNCTION.keys())}"
        )

    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # All supported schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    # All supported schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.INVERSE_SQRT:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    return schedule_func(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        last_epoch=last_epoch,
    )

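A minimal usage sketch of this unified API; the toy model, optimizer, and step counts are assumptions for illustration.

import torch

optimizer = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=1e-4)
lr_scheduler = get_scheduler(
    "cosine",  # or SchedulerType.COSINE, or "inverse_sqrt"
    optimizer,
    num_warmup_steps=500,
    num_training_steps=2000,
)

for _ in range(2000):
    ...  # forward/backward and optimizer.zero_grad() would go here
    optimizer.step()
    lr_scheduler.step()
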
@@ -69,7 +69,7 @@ def make_optimizer_and_scheduler(cfg, policy):
         cfg.training.adam_eps,
         cfg.training.adam_weight_decay,
     )
-    from diffusers.optimization import get_scheduler
+    from transformers.optimization import get_scheduler

     lr_scheduler = get_scheduler(
         cfg.training.lr_scheduler,

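For context, a hypothetical sketch of how the swapped-in transformers get_scheduler is typically wired inside make_optimizer_and_scheduler; cfg and optimizer come from the enclosing function, and the lr_warmup_steps / offline_steps config fields are assumptions, not taken from this hunk.

from transformers.optimization import get_scheduler

lr_scheduler = get_scheduler(
    cfg.training.lr_scheduler,  # e.g. "cosine"
    optimizer=optimizer,
    num_warmup_steps=cfg.training.lr_warmup_steps,  # assumed config field
    num_training_steps=cfg.training.offline_steps,  # assumed config field
)
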
@@ -0,0 +1,72 @@
import math

import pytest
import torch

from lerobot.common.policies.lr_schedulers import get_scheduler


def test_get_lr_scheduler():
    optimizer = torch.optim.AdamW(torch.nn.Linear(10, 10).parameters(), lr=1e-4)

    lr_scheduler = get_scheduler("cosine", optimizer, num_warmup_steps=500, num_training_steps=2000)
    assert lr_scheduler is not None
    assert lr_scheduler.__class__.__name__ == "LambdaLR"

    lr_scheduler = get_scheduler("inverse_sqrt", optimizer, num_warmup_steps=500, num_training_steps=2000)
    assert lr_scheduler is not None
    assert lr_scheduler.__class__.__name__ == "LambdaLR"

    with pytest.raises(ValueError):
        get_scheduler("invalid", optimizer, num_warmup_steps=500, num_training_steps=2000)


def test_cosine_lr_scheduler():
    intervals = 250
    num_warmup_steps = 500
    num_training_steps = 2000
    # Reference lrs sampled every 250 steps, rounded to two significant figures.
    recorded_lrs_at_intervals = [2.0e-7, 5.0e-5, 1.0e-4, 9.3e-5, 7.5e-5, 5.0e-5, 2.5e-5, 6.6e-6]
    optimizer = torch.optim.AdamW(
        torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6
    )

    lr_scheduler = get_scheduler(
        "cosine", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )
    assert lr_scheduler.get_last_lr()[0] == 0.0

    for i in range(num_training_steps):
        optimizer.step()
        lr_scheduler.step()
        if i % intervals == 0:
            recorded = recorded_lrs_at_intervals.pop(0)
            # The reference values are rounded, so compare with a loose relative tolerance.
            assert math.isclose(
                lr_scheduler.get_last_lr()[0], recorded, rel_tol=2e-2
            ), f"LR value mismatch at step {i}: {lr_scheduler.get_last_lr()[0]} vs. {recorded}"

    # After the full schedule the cosine decay reaches exactly 0.
    assert lr_scheduler.get_last_lr()[0] == 0.0

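For reference, a quick sketch of where a value like the 7.5e-5 entry comes from; this recomputes the cosine schedule by hand at optimizer step i = 1000 (the scheduler has then taken 1001 steps), independent of any training state.

import math

progress = (1001 - 500) / (2000 - 500)
lr = 1e-4 * 0.5 * (1.0 + math.cos(math.pi * progress))
print(lr)  # ≈ 7.49e-05, i.e. 7.5e-5 to two significant figures
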
def test_inverse_sqrt_lr_scheduler():
    intervals = 250
    num_warmup_steps = 500
    num_training_steps = 2000
    # Reference lrs sampled every 250 steps, rounded to two significant figures.
    recorded_lrs_at_intervals = [2.0e-7, 5.0e-5, 1.0e-4, 8.2e-5, 7.1e-5, 6.3e-5, 5.8e-5, 5.3e-5]
    optimizer = torch.optim.AdamW(
        torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6
    )

    lr_scheduler = get_scheduler(
        "inverse_sqrt", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )

    for i in range(num_training_steps):
        optimizer.step()
        lr_scheduler.step()
        if i % intervals == 0:
            recorded = recorded_lrs_at_intervals.pop(0)
            # The reference values are rounded, so compare with a loose relative tolerance.
            assert math.isclose(
                lr_scheduler.get_last_lr()[0], recorded, rel_tol=2e-2
            ), f"LR value mismatch at step {i}: {lr_scheduler.get_last_lr()[0]} vs. {recorded}"

    # After 2000 steps the lr is 1e-4 / sqrt(2000 / 500) = 5e-5.
    assert math.isclose(lr_scheduler.get_last_lr()[0], 5.0e-5)

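The same sanity check for the inverse-sqrt reference values: with timescale defaulting to num_warmup_steps (500), the lr at optimizer step i = 1000 (scheduler step 1001) follows the closed form below.

import math

lr = 1e-4 / math.sqrt(1001 / 500)
print(lr)  # ≈ 7.07e-05, i.e. 7.1e-5 to two significant figures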