add schedulers and test files

This commit is contained in:
Akshay Kashyap 2024-05-29 22:57:32 -07:00
parent a30e4e523a
commit 7baa282af8
2 changed files with 255 additions and 0 deletions

View File

@ -0,0 +1,161 @@
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch learning rate schedulers.
Note: Most of this code was copied as is from the diffusers and transformers libraries with removal of
certain features for simplication.
"""
import math
from enum import Enum
from functools import partial
from typing import Optional, Union
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
class SchedulerType(Enum):
COSINE = "cosine"
INVERSE_SQRT = "inverse_sqrt"
def get_cosine_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float = 0.5,
last_epoch: int = -1,
) -> LambdaLR:
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
shift = timescale - num_warmup_steps
decay = 1.0 / math.sqrt((current_step + shift) / timescale)
return decay
def get_inverse_sqrt_schedule(
optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
):
"""
Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
timescale (`int`, *optional*, defaults to `num_warmup_steps`):
Time scale.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
# Note: this implementation is adapted from
# https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
if timescale is None:
timescale = num_warmup_steps or 10_000
lr_lambda = partial(
_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale
)
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
TYPE_TO_SCHEDULER_FUNCTION = {
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
}
def get_scheduler(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
last_epoch: int = -1,
) -> LambdaLR:
"""
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = SchedulerType(name)
if name not in TYPE_TO_SCHEDULER_FUNCTION:
raise ValueError(
f"Unsupported scheduler {name}, expected one of {list(TYPE_TO_SCHEDULER_FUNCTION.keys())}"
)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
if name == SchedulerType.INVERSE_SQRT:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
last_epoch=last_epoch,
)

View File

@ -0,0 +1,94 @@
import math
import pytest
import torch
from lerobot.common.policies.lr_schedulers import get_scheduler
def test_get_lr_scheduler():
optimizer = torch.optim.AdamW(torch.nn.Linear(10, 10).parameters(), lr=1e-4)
lr_scheduler = get_scheduler("cosine", optimizer, num_warmup_steps=500, num_training_steps=2000)
assert lr_scheduler is not None
assert lr_scheduler.__class__.__name__ == "LambdaLR"
lr_scheduler = get_scheduler("inverse_sqrt", optimizer, num_warmup_steps=500, num_training_steps=2000)
assert lr_scheduler is not None
assert lr_scheduler.__class__.__name__ == "LambdaLR"
with pytest.raises(ValueError):
get_scheduler("invalid", 100, 1000)
def test_cosine_lr_scheduler():
intervals = 250
num_warmup_steps = 500
num_training_steps = 2000
recorded_lrs_at_intervals = [
2.0e-7,
5.0200000e-5,
9.9999890e-5,
9.3248815e-5,
7.4909255e-5,
4.9895280e-5,
2.4909365e-5,
6.6464649e-6,
]
optimizer = torch.optim.AdamW(
torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6
)
lr_scheduler = get_scheduler(
"cosine", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)
assert lr_scheduler.get_last_lr()[0] == 0.0
for i in range(num_training_steps - intervals):
lr_scheduler.step()
if i % intervals == 0:
recorded = recorded_lrs_at_intervals.pop(0)
assert math.isclose(
lr_scheduler.get_last_lr()[0], recorded, abs_tol=1e-7
), f"LR value mismatch at step {i}: {lr_scheduler.get_last_lr()[0]} vs. {recorded}"
lr_scheduler.step()
assert math.isclose(
lr_scheduler.get_last_lr()[0], recorded_lrs_at_intervals[-1], abs_tol=1e-7
), f"LR value mismatch at step {num_training_steps}: {lr_scheduler.get_last_lr()[0]} vs. {recorded_lrs_at_intervals[-1]}"
def test_inverse_sqrt_lr_scheduler():
intervals = 250
num_warmup_steps = 500
num_training_steps = 2000
recorded_lrs_at_intervals = [
2.0e-7,
5.02e-5,
9.9900150e-05,
8.1595279e-05,
7.0675349e-05,
6.3220270e-05,
5.7715792e-5,
5.3436983e-5,
]
optimizer = torch.optim.AdamW(
torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6
)
lr_scheduler = get_scheduler(
"inverse_sqrt", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)
for i in range(num_training_steps - intervals):
lr_scheduler.step()
if i % intervals == 0:
recorded = recorded_lrs_at_intervals.pop(0)
assert math.isclose(
lr_scheduler.get_last_lr()[0], recorded, abs_tol=1e-7
), f"LR value mismatch at step {i}: {lr_scheduler.get_last_lr()[0]} vs. {recorded}"
lr_scheduler.step()
assert math.isclose(
lr_scheduler.get_last_lr()[0], recorded_lrs_at_intervals[-1], abs_tol=1e-7
), f"LR value mismatch at step {num_training_steps}: {lr_scheduler.get_last_lr()[0]} vs. {recorded_lrs_at_intervals[-1]}"