diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py index f17cd7ce..0bfa503a 100644 --- a/tests/test_lr_schedulers.py +++ b/tests/test_lr_schedulers.py @@ -38,7 +38,7 @@ def test_cosine_lr_scheduler(): for i in range(num_training_steps): optimizer.step() lr_scheduler.step() - if i % intervals == 0: + if i == 0 or (i + 1) % intervals == 0: recorded = recorded_lrs_at_intervals.pop(0) assert math.isclose( lr_scheduler.get_last_lr()[0], recorded @@ -61,9 +61,8 @@ def test_inverse_sqrt_lr_scheduler(): ) for i in range(num_training_steps): - optimizer.step() lr_scheduler.step() - if i % intervals == 0: + if i == 0 or (i + 1) % intervals == 0: recorded = recorded_lrs_at_intervals.pop(0) assert math.isclose( lr_scheduler.get_last_lr()[0], recorded