diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index bcd7d5ddd..67a0d721c 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1376,7 +1376,7 @@ def schedule(step): boundaries = [] if warmup_steps > 0: - warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps - 1) + warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps) pieces.append(warmup_schedule) boundaries.append(warmup_steps) diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index a23e8bad1..a53ed094b 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -748,6 +748,11 @@ def test_cosine_schedule(self): # Warmup phase: 0 -> peak self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6) self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6) + # Ensure delta is constant + expected_slope = learning_rate / warmup_steps + for i in range(1, warmup_steps + 1): + current_lr = float(schedule_fn(i)) + self.assertAlmostEqual(current_lr - float(schedule_fn(i - 1)), expected_slope, places=6) # Cosine decay phase lr_end = schedule_fn(learning_rate_schedule_steps - 1) @@ -791,6 +796,11 @@ def test_wsd_schedule(self): # Warmup phase: 0 -> peak self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6) self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6) + # Ensure delta is constant + expected_slope = learning_rate / warmup_steps + for i in range(1, warmup_steps + 1): + current_lr = float(schedule_fn(i)) + self.assertAlmostEqual(current_lr - float(schedule_fn(i - 1)), expected_slope, places=6) # Stable phase: constant at peak self.assertAlmostEqual(float(schedule_fn(warmup_steps + 10)), learning_rate, places=6)