lerobot/lerobot/common/policies/lr_schedulers.py

162 lines
6.2 KiB
Python

# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch learning rate schedulers.
Note: Most of this code was copied as is from the diffusers and transformers libraries, with certain
features removed for simplification.
"""
import math
from enum import Enum
from functools import partial
from typing import Optional, Union
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
class SchedulerType(Enum):
    """Identifiers of the learning-rate schedules this module can build.

    `get_scheduler` converts its `name` argument with `SchedulerType(name)`, so either a member or
    its string value can be passed there.
    """

    # Linear warmup followed by a cosine-shaped decay.
    COSINE = "cosine"
    # Linear warmup followed by an inverse square-root decay.
    INVERSE_SQRT = "inverse_sqrt"
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Build a `LambdaLR` whose multiplier ramps up linearly during warmup and then follows a cosine curve.

    During the first `num_warmup_steps` steps the multiplier applied to the optimizer's initial lr grows
    linearly from 0 towards 1. Afterwards it is `0.5 * (1 + cos(2 * pi * num_cycles * progress))`, where
    `progress` is the fraction of the post-warmup steps already done; the value is never allowed to go
    below 0.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of cosine periods covered between the end of warmup and `num_training_steps`. With
            the default of 0.5 the multiplier goes from 1 down to 0 over that span.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def cosine_multiplier(step):
        in_warmup = step < num_warmup_steps
        if in_warmup:
            # `max(1, ...)` guards the division when no warmup is requested.
            warmup_span = float(max(1, num_warmup_steps))
            return float(step) / warmup_span

        # Fraction of the post-warmup phase that has elapsed.
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(step - num_warmup_steps) / decay_span
        cosine_term = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine_term))

    return LambdaLR(optimizer, cosine_multiplier, last_epoch)
def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
shift = timescale - num_warmup_steps
decay = 1.0 / math.sqrt((current_step + shift) / timescale)
return decay
def get_inverse_sqrt_schedule(
    optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
):
    """
    Build a `LambdaLR` with a linear warmup followed by an inverse square-root decay.

    The multiplier applied to the optimizer's initial lr rises linearly from 0 during the warmup phase and
    then decays proportionally to the inverse square root of the step count.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
            Time scale. If `num_warmup_steps` is 0 as well, 10_000 is used.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Note: this implementation is adapted from
    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
    effective_timescale = timescale if timescale is not None else (num_warmup_steps or 10_000)

    # Bind the schedule parameters so that LambdaLR only has to supply the current step.
    schedule_kwargs = {"num_warmup_steps": num_warmup_steps, "timescale": effective_timescale}
    multiplier_fn = partial(_get_inverse_sqrt_schedule_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, multiplier_fn, last_epoch=last_epoch)
# Registry used by `get_scheduler` to look up the factory function for each scheduler type.
TYPE_TO_SCHEDULER_FUNCTION = {
    # Linear warmup, then cosine decay.
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    # Linear warmup, then inverse square-root decay.
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
}
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
            The inverse square-root scheduler does not use it.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: If `name` is not a valid or supported scheduler type, or if an argument required by
            the selected scheduler is `None`.
    """
    # The Enum lookup itself raises ValueError for strings that match no member.
    name = SchedulerType(name)
    if name not in TYPE_TO_SCHEDULER_FUNCTION:
        raise ValueError(
            f"Unsupported scheduler {name}, expected one of {list(TYPE_TO_SCHEDULER_FUNCTION.keys())}"
        )
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # Every scheduler registered here needs `num_warmup_steps`.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    # The inverse square-root schedule has no notion of a total step count, so it is dispatched
    # before `num_training_steps` is validated. `last_epoch` is forwarded so that resuming works.
    if name == SchedulerType.INVERSE_SQRT:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)

    # The remaining schedulers require `num_training_steps`.
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    return schedule_func(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        last_epoch=last_epoch,
    )