diff --git a/lerobot/common/policies/lr_schedulers.py b/lerobot/common/policies/lr_schedulers.py
new file mode 100644
index 00000000..ca223e18
--- /dev/null
+++ b/lerobot/common/policies/lr_schedulers.py
@@ -0,0 +1,161 @@
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch learning rate schedulers.
+
+Note: Most of this code was copied as-is from the diffusers and transformers libraries, with certain
+features removed for simplification.
+"""
+
+import math
+from enum import Enum
+from functools import partial
+from typing import Optional, Union
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+
+
+class SchedulerType(Enum):
+    COSINE = "cosine"
+    INVERSE_SQRT = "inverse_sqrt"
+
+
+def get_cosine_schedule_with_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+) -> LambdaLR:
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function from the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the
+    initial lr set in the optimizer.
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(
+            max(1, num_training_steps - num_warmup_steps)
+        )
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    shift = timescale - num_warmup_steps
+    decay = 1.0 / math.sqrt((current_step + shift) / timescale)
+    return decay
+
+
+def get_inverse_sqrt_schedule(
+    optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
+):
+    """
+    Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
+    warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
+            Time scale.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+    # Note: this implementation is adapted from
+    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
+
+    if timescale is None:
+        timescale = num_warmup_steps or 10_000
+
+    lr_lambda = partial(
+        _get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
+}
+
+
+def get_scheduler(
+    name: Union[str, SchedulerType],
+    optimizer: Optimizer,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+    last_epoch: int = -1,
+) -> LambdaLR:
+    """
+    Unified API to get any scheduler from its name.
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+ """ + name = SchedulerType(name) + if name not in TYPE_TO_SCHEDULER_FUNCTION: + raise ValueError( + f"Unsupported scheduler {name}, expected one of {list(TYPE_TO_SCHEDULER_FUNCTION.keys())}" + ) + + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py new file mode 100644 index 00000000..7b0bcac2 --- /dev/null +++ b/tests/test_lr_schedulers.py @@ -0,0 +1,94 @@ +import math + +import pytest +import torch + +from lerobot.common.policies.lr_schedulers import get_scheduler + + +def test_get_lr_scheduler(): + optimizer = torch.optim.AdamW(torch.nn.Linear(10, 10).parameters(), lr=1e-4) + + lr_scheduler = get_scheduler("cosine", optimizer, num_warmup_steps=500, num_training_steps=2000) + assert lr_scheduler is not None + assert lr_scheduler.__class__.__name__ == "LambdaLR" + + lr_scheduler = get_scheduler("inverse_sqrt", optimizer, num_warmup_steps=500, num_training_steps=2000) + assert lr_scheduler is not None + assert lr_scheduler.__class__.__name__ == "LambdaLR" + + with pytest.raises(ValueError): + get_scheduler("invalid", 100, 1000) + + +def test_cosine_lr_scheduler(): + intervals = 250 + num_warmup_steps = 500 + num_training_steps = 2000 + recorded_lrs_at_intervals = [ + 2.0e-7, + 5.0200000e-5, + 9.9999890e-5, + 9.3248815e-5, + 7.4909255e-5, + 4.9895280e-5, + 2.4909365e-5, + 6.6464649e-6, + ] + optimizer = torch.optim.AdamW( + torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6 + ) + + lr_scheduler = get_scheduler( + "cosine", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps + ) + assert lr_scheduler.get_last_lr()[0] == 0.0 + + for i in range(num_training_steps - intervals): + lr_scheduler.step() + if i % intervals == 0: + recorded = recorded_lrs_at_intervals.pop(0) + assert math.isclose( + lr_scheduler.get_last_lr()[0], recorded, abs_tol=1e-7 + ), f"LR value mismatch at step {i}: {lr_scheduler.get_last_lr()[0]} vs. {recorded}" + + lr_scheduler.step() + assert math.isclose( + lr_scheduler.get_last_lr()[0], recorded_lrs_at_intervals[-1], abs_tol=1e-7 + ), f"LR value mismatch at step {num_training_steps}: {lr_scheduler.get_last_lr()[0]} vs. 
+
+
+def test_inverse_sqrt_lr_scheduler():
+    intervals = 250
+    num_warmup_steps = 500
+    num_training_steps = 2000
+    recorded_lrs_at_intervals = [
+        2.0e-7,
+        5.02e-5,
+        9.9900150e-05,
+        8.1595279e-05,
+        7.0675349e-05,
+        6.3220270e-05,
+        5.7715792e-5,
+        5.3436983e-5,
+    ]
+    optimizer = torch.optim.AdamW(
+        torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6
+    )
+
+    lr_scheduler = get_scheduler(
+        "inverse_sqrt", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
+    )
+
+    for i in range(num_training_steps - intervals):
+        lr_scheduler.step()
+        if i % intervals == 0:
+            recorded = recorded_lrs_at_intervals.pop(0)
+            assert math.isclose(
+                lr_scheduler.get_last_lr()[0], recorded, abs_tol=1e-7
+            ), f"LR value mismatch at step {i}: {lr_scheduler.get_last_lr()[0]} vs. {recorded}"
+
+    lr_scheduler.step()
+    assert math.isclose(
+        lr_scheduler.get_last_lr()[0], recorded_lrs_at_intervals[-1], abs_tol=1e-7
+    ), f"LR value mismatch at step {num_training_steps}: {lr_scheduler.get_last_lr()[0]} vs. {recorded_lrs_at_intervals[-1]}"
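
Usage sketch (illustrative, not part of the diff above): a minimal example of how `get_scheduler` might be wired into a training loop. The linear layer is a placeholder for a real policy network, and the step counts simply mirror the values used in the tests; in practice they would come from the training configuration.

import torch

from lerobot.common.policies.lr_schedulers import get_scheduler

policy = torch.nn.Linear(10, 10)  # placeholder for an actual policy network
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-4)
lr_scheduler = get_scheduler("cosine", optimizer, num_warmup_steps=500, num_training_steps=2000)

for step in range(2000):
    # ... forward pass and loss.backward() would go here ...
    optimizer.step()
    optimizer.zero_grad()
    lr_scheduler.step()  # advance the schedule once per optimizer update
    current_lr = lr_scheduler.get_last_lr()[0]  # ramps to 1e-4 over 500 steps, then decays toward 0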
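
The recorded learning rates in both tests follow directly from the closed-form schedules. As a quick cross-check (assuming the tests' hyperparameters: lr=1e-4, num_warmup_steps=500, num_training_steps=2000, and timescale defaulting to num_warmup_steps), here are the values asserted at i == 750, i.e. after 751 calls to lr_scheduler.step():

import math

base_lr, warmup, total, timescale, step = 1e-4, 500, 2000, 500, 751

# Cosine past warmup: lr = base_lr * 0.5 * (1 + cos(pi * num_cycles * 2 * progress)) with num_cycles = 0.5
progress = (step - warmup) / (total - warmup)
print(base_lr * 0.5 * (1.0 + math.cos(math.pi * progress)))  # ~9.32488e-05, cf. recorded 9.3248815e-5

# Inverse sqrt past warmup: lr = base_lr / sqrt((step + shift) / timescale) with shift = timescale - warmup = 0
shift = timescale - warmup
print(base_lr / math.sqrt((step + shift) / timescale))  # ~8.15953e-05, cf. recorded 8.1595279e-5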