test condition
This commit is contained in:
parent
90e6df3ecb
commit
f2e50a351d
|
@ -38,7 +38,7 @@ def test_cosine_lr_scheduler():
|
||||||
for i in range(num_training_steps):
|
for i in range(num_training_steps):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
if i % intervals == 0:
|
if i == 0 or (i + 1) % intervals == 0:
|
||||||
recorded = recorded_lrs_at_intervals.pop(0)
|
recorded = recorded_lrs_at_intervals.pop(0)
|
||||||
assert math.isclose(
|
assert math.isclose(
|
||||||
lr_scheduler.get_last_lr()[0], recorded
|
lr_scheduler.get_last_lr()[0], recorded
|
||||||
|
@ -61,9 +61,8 @@ def test_inverse_sqrt_lr_scheduler():
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in range(num_training_steps):
|
for i in range(num_training_steps):
|
||||||
optimizer.step()
|
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
if i % intervals == 0:
|
if i == 0 or (i + 1) % intervals == 0:
|
||||||
recorded = recorded_lrs_at_intervals.pop(0)
|
recorded = recorded_lrs_at_intervals.pop(0)
|
||||||
assert math.isclose(
|
assert math.isclose(
|
||||||
lr_scheduler.get_last_lr()[0], recorded
|
lr_scheduler.get_last_lr()[0], recorded
|
||||||
|
|
Loading…
Reference in New Issue