# lerobot/tests/test_lr_schedulers.py
#
# (95 lines, 3.1 KiB, Python)

import math
import pytest
import torch
from lerobot.common.policies.lr_schedulers import get_scheduler
def test_get_lr_scheduler():
    """`get_scheduler` returns a LambdaLR for each supported name and rejects unknown names."""
    optimizer = torch.optim.AdamW(torch.nn.Linear(10, 10).parameters(), lr=1e-4)

    for name in ("cosine", "inverse_sqrt"):
        lr_scheduler = get_scheduler(name, optimizer, num_warmup_steps=500, num_training_steps=2000)
        assert lr_scheduler is not None
        assert lr_scheduler.__class__.__name__ == "LambdaLR", f"unexpected scheduler type for {name!r}"

    # Pass a real optimizer (the original passed the int 100 in the optimizer slot) so the
    # ValueError can only be attributed to the unknown scheduler name.
    with pytest.raises(ValueError):
        get_scheduler("invalid", optimizer, num_warmup_steps=100, num_training_steps=1000)
def test_cosine_lr_scheduler():
    """Compare the "cosine" schedule against recorded learning rates.

    ``recorded_lrs_at_intervals[k]`` is the learning rate expected after exactly
    ``k * intervals + 1`` calls to ``lr_scheduler.step()`` (i.e. steps 1, 251, ..., 1751).
    """
    intervals = 250
    num_warmup_steps = 500
    num_training_steps = 2000
    recorded_lrs_at_intervals = [
        2.0e-7,
        5.0200000e-5,
        9.9999890e-5,
        9.3248815e-5,
        7.4909255e-5,
        4.9895280e-5,
        2.4909365e-5,
        6.6464649e-6,
    ]
    optimizer = torch.optim.AdamW(
        torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6
    )
    lr_scheduler = get_scheduler(
        "cosine", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )
    # Before any step() call the learning rate is expected to be exactly 0.
    assert lr_scheduler.get_last_lr()[0] == 0.0

    # Number of step() calls taken -> expected learning rate at that point.
    expected_lr_by_step = {k * intervals + 1: lr for k, lr in enumerate(recorded_lrs_at_intervals)}
    last_checked_step = max(expected_lr_by_step)
    assert last_checked_step <= num_training_steps

    for step in range(1, last_checked_step + 1):
        lr_scheduler.step()
        expected = expected_lr_by_step.get(step)
        if expected is None:
            continue
        actual = lr_scheduler.get_last_lr()[0]
        assert math.isclose(actual, expected, abs_tol=1e-7), (
            f"LR value mismatch at step {step}: {actual} vs. {expected}"
        )
def test_inverse_sqrt_lr_scheduler():
    """Compare the "inverse_sqrt" schedule against recorded learning rates.

    ``recorded_lrs_at_intervals[k]`` is the learning rate expected after exactly
    ``k * intervals + 1`` calls to ``lr_scheduler.step()`` (i.e. steps 1, 251, ..., 1751).
    """
    intervals = 250
    num_warmup_steps = 500
    num_training_steps = 2000
    recorded_lrs_at_intervals = [
        2.0e-7,
        5.0200000e-5,
        9.9900150e-5,
        8.1595279e-5,
        7.0675349e-5,
        6.3220270e-5,
        5.7715792e-5,
        5.3436983e-5,
    ]
    optimizer = torch.optim.AdamW(
        torch.nn.Linear(10, 10).parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6
    )
    lr_scheduler = get_scheduler(
        "inverse_sqrt", optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )

    # Number of step() calls taken -> expected learning rate at that point.
    expected_lr_by_step = {k * intervals + 1: lr for k, lr in enumerate(recorded_lrs_at_intervals)}
    last_checked_step = max(expected_lr_by_step)
    assert last_checked_step <= num_training_steps

    for step in range(1, last_checked_step + 1):
        lr_scheduler.step()
        expected = expected_lr_by_step.get(step)
        if expected is None:
            continue
        actual = lr_scheduler.get_last_lr()[0]
        assert math.isclose(actual, expected, abs_tol=1e-7), (
            f"LR value mismatch at step {step}: {actual} vs. {expected}"
        )