From 5f00b25e3e1bf4d1bc4847c66fc87f7b6e63dbe5 Mon Sep 17 00:00:00 2001 From: Jimmy Tsai Date: Thu, 5 Feb 2026 09:24:31 +0000 Subject: [PATCH] Fix learning rate schedule --- src/maxtext/utils/maxtext_utils.py | 2 +- tests/unit/maxtext_utils_test.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index bcd7d5ddd..67a0d721c 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1376,7 +1376,7 @@ def schedule(step): boundaries = [] if warmup_steps > 0: - warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps - 1) + warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps) pieces.append(warmup_schedule) boundaries.append(warmup_steps) diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index a23e8bad1..a53ed094b 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -748,6 +748,11 @@ def test_cosine_schedule(self): # Warmup phase: 0 -> peak self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6) self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6) + # Ensure delta is constant + expected_slope = learning_rate / warmup_steps + for i in range(1, warmup_steps + 1): + current_lr = float(schedule_fn(i)) + self.assertAlmostEqual(current_lr - float(schedule_fn(i - 1)), expected_slope, places=6) # Cosine decay phase lr_end = schedule_fn(learning_rate_schedule_steps - 1) @@ -791,6 +796,11 @@ def test_wsd_schedule(self): # Warmup phase: 0 -> peak self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6) self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6) + # Ensure delta is constant + expected_slope = learning_rate / warmup_steps + for i in range(1, warmup_steps + 1): + current_lr = float(schedule_fn(i)) + self.assertAlmostEqual(current_lr - float(schedule_fn(i - 1)), expected_slope, places=6) # Stable phase: constant at peak self.assertAlmostEqual(float(schedule_fn(warmup_steps + 10)), learning_rate, places=6)