diff --git a/lerobot/common/policies/lr_schedulers.py b/lerobot/common/policies/lr_schedulers.py
new file mode 100644
index 00000000..93569404
--- /dev/null
+++ b/lerobot/common/policies/lr_schedulers.py
@@ -0,0 +1,166 @@
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch learning rate schedulers.
+
+Note: Most of this code was copied from the diffusers and transformers libraries, with certain features
+removed for simplification.
+"""
+
+import math
+from enum import Enum
+from functools import partial
+from typing import Optional, Union
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+
+
+class SchedulerType(Enum):
+    COSINE = "cosine"
+    INVERSE_SQRT = "inverse_sqrt"
+
+
+def get_cosine_schedule_with_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+) -> LambdaLR:
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function from the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and
+    the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of periods of the cosine function in the schedule (the default of 0.5 decays from the max
+            value to 0 following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(
+            max(1, num_training_steps - num_warmup_steps)
+        )
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_inverse_sqrt_schedule_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, timescale: Optional[int] = None
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    shift = timescale - num_warmup_steps
+    decay = 1.0 / math.sqrt((current_step + shift) / timescale)
+    return decay
+
+
+def get_inverse_sqrt_schedule(
+    optimizer: Optimizer, num_warmup_steps: int, timescale: Optional[int] = None, last_epoch: int = -1
+) -> LambdaLR:
+    """
+    Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
+    warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
+            Time scale.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+    # Note: this implementation is adapted from
+    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
+
+    if timescale is None:
+        timescale = num_warmup_steps or 10_000
+
+    lr_lambda = partial(
+        _get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
+}
+
+
+def get_scheduler(
+    name: Union[str, SchedulerType],
+    optimizer: Optimizer,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+    last_epoch: int = -1,
+) -> LambdaLR:
+    """
+    Unified API to get any scheduler from its name.
+
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+ """ + name = SchedulerType(name) + if name not in TYPE_TO_SCHEDULER_FUNCTION: + raise ValueError( + f"Unsupported scheduler {name}, expected one of {list(TYPE_TO_SCHEDULER_FUNCTION.keys())}" + ) + + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 2b28943d..1e5a42e8 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -69,7 +69,7 @@ def make_optimizer_and_scheduler(cfg, policy): cfg.training.adam_eps, cfg.training.adam_weight_decay, ) - from diffusers.optimization import get_scheduler + from transformers.optimization import get_scheduler lr_scheduler = get_scheduler( cfg.training.lr_scheduler, diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py new file mode 100644 index 00000000..f17cd7ce --- /dev/null +++ b/tests/test_lr_schedulers.py @@ -0,0 +1,72 @@ +import math + +import pytest +import torch + +from lerobot.common.policies.lr_schedulers import get_scheduler + + +def test_get_lr_scheduler(): + optimizer = torch.optim.AdamW(torch.nn.Linear(10, 10).parameters(), lr=1e-4) + + lr_scheduler = get_scheduler("cosine", optimizer, num_warmup_steps=500, num_training_steps=2000) + assert lr_scheduler is not None + assert lr_scheduler.__class__.__name__ == "LambdaLR" + + lr_scheduler = get_scheduler("inverse_sqrt", optimizer, num_warmup_steps=500, num_training_steps=2000) + assert lr_scheduler is not None + assert lr_scheduler.__class__.__name__ == "LambdaLR" + + with pytest.raises(ValueError): + get_scheduler("invalid", 100, 1000) + + +def test_cosine_lr_scheduler(): + intervals = 250 + num_warmup_steps = 500 + num_training_steps = 2000 + recorded_lrs_at_intervals = [2.0e-7, 5.0e-5, 1.0e-4, 9.3e-5, 7.5e-5, 5.0e-5, 2.5e-5, 6.6e-6] + optimizer = torch.optim.AdamW( + torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6 + ) + + lr_scheduler = get_scheduler( + "cosine", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps + ) + assert lr_scheduler.get_last_lr()[0] == 0.0 + + for i in range(num_training_steps): + optimizer.step() + lr_scheduler.step() + if i % intervals == 0: + recorded = recorded_lrs_at_intervals.pop(0) + assert math.isclose( + lr_scheduler.get_last_lr()[0], recorded + ), f"LR value mismatch at step {i}: {lr_scheduler.get_last_lr()[0]} vs. 
{recorded}" + + assert lr_scheduler.get_last_lr()[0] == recorded_lrs_at_intervals.pop(0) + + +def test_inverse_sqrt_lr_scheduler(): + intervals = 250 + num_warmup_steps = 500 + num_training_steps = 2000 + recorded_lrs_at_intervals = [2.0e-7, 5.0e-5, 1.0e-4, 8.2e-5, 7.1e-5, 6.3e-5, 5.8e-5, 5.3e-5] + optimizer = torch.optim.AdamW( + torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6 + ) + + lr_scheduler = get_scheduler( + "inverse_sqrt", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps + ) + + for i in range(num_training_steps): + optimizer.step() + lr_scheduler.step() + if i % intervals == 0: + recorded = recorded_lrs_at_intervals.pop(0) + assert math.isclose( + lr_scheduler.get_last_lr()[0], recorded + ), f"LR value mismatch at step {i}: {lr_scheduler.get_last_lr()[0]} vs. {recorded}" + + assert lr_scheduler.get_last_lr()[0] == recorded_lrs_at_intervals.pop(0)