From fa6ff3c013e3e1e16180ad3473d12439e9c774e3 Mon Sep 17 00:00:00 2001 From: Richard <99752583+rxu183@users.noreply.github.com> Date: Thu, 5 Feb 2026 11:55:21 -0600 Subject: [PATCH 1/2] Add qwen3 1.7B --- .../convert_checkpoint.md | 2 +- src/MaxText/configs/models/qwen3-1.7b.yml | 37 +++++++++++++++++++ src/MaxText/configs/types.py | 1 + src/MaxText/pyconfig_deprecated.py | 1 + .../ckpt_conversion/utils/hf_model_configs.py | 17 +++++++++ .../utils/ckpt_conversion/utils/hf_shape.py | 1 + .../ckpt_conversion/utils/param_mapping.py | 2 + .../utils/ckpt_conversion/utils/utils.py | 1 + tests/unit/configs_test.py | 1 + tests/unit/muon_test.py | 2 +- tests/utils/sharding_dump.py | 1 + 11 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 src/MaxText/configs/models/qwen3-1.7b.yml diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index b37d2923c8..a22d63e31f 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -11,7 +11,7 @@ The following models are supported: | **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ | | **Gemma3** (Multimodal) | 4B, 12B, 27B | - | √ | - | √ | | **Llama3.1** | 8B, 70B, 450B | √ | √ | √ | √ | -| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B | √ | √ | √ | √ | +| **Qwen3** | 0.6B, 1.7B, 4B, 8B, 14B, 32B | √ | √ | √ | √ | | **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ | | **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ | | **GPT-OSS** | 20B, 120B | √ | √ | √ | √ | diff --git a/src/MaxText/configs/models/qwen3-1.7b.yml b/src/MaxText/configs/models/qwen3-1.7b.yml new file mode 100644 index 0000000000..79b7abe8e7 --- /dev/null +++ b/src/MaxText/configs/models/qwen3-1.7b.yml @@ -0,0 +1,37 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for qwen3-1.7b + +base_emb_dim: 2048 +base_num_query_heads: 16 +base_num_kv_heads: 8 +base_mlp_dim: 6144 +base_num_decoder_layers: 28 +head_dim: 128 +mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU +vocab_size: 151936 + +decoder_block: "qwen3" + +normalization_layer_epsilon: 1.0e-6 +rope_max_timescale: 1000000 + +use_qk_norm: True + +logits_via_embedding: True # from "tie_word_embeddings": true +normalize_embedding_logits: False +enable_dropout: False # deterministic for testing + +tokenizer_type: "huggingface" diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 40c2404b76..2be7c523bc 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -221,6 +221,7 @@ class ProfilerType(str, Enum): "gemma3-12b", "gemma3-27b", "qwen3-0.6b", + "qwen3-1.7b", "qwen3-4b", "qwen3-4b-thinking-2507", "qwen3-8b", diff --git a/src/MaxText/pyconfig_deprecated.py b/src/MaxText/pyconfig_deprecated.py index 582d7a122f..7d8bbebcef 100644 --- a/src/MaxText/pyconfig_deprecated.py +++ b/src/MaxText/pyconfig_deprecated.py @@ -461,6 +461,7 @@ def validate_model_name(s: str) -> bool: "gemma3-12b", "gemma3-27b", "qwen3-0.6b", + "qwen3-1.7b", "qwen3-4b", "qwen3-4b-thinking-2507", "qwen3-8b", diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py index d91b7987ca..1eab21606d 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py +++ b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py @@ -226,6 +226,22 @@ torch_dtype="bfloat16", ) +qwen3_1_7b_config = transformers.Qwen3Config( + vocab_size=151936, + hidden_size=2048, + intermediate_size=6144, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=8, + head_dim=128, + hidden_act="silu", + max_position_embeddings=40960, + rms_norm_eps=1.0e-6, + rope_theta=1000000.0, + tie_word_embeddings=True, + torch_dtype="bfloat16", +) + qwen3_4b_config = transformers.Qwen3Config( vocab_size=151936, hidden_size=2560, @@ -773,6 +789,7 @@ "gemma3-12b": gemma3_12b_config, "gemma3-27b": gemma3_27b_config, "qwen3-0.6b": qwen3_0_6b_config, + "qwen3-1.7b": qwen3_1_7b_config, "qwen3-4b": qwen3_4b_config, "qwen3-4b-thinking-2507": qwen3_4b_config, "qwen3-8b": qwen3_8b_config, diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py index 081017dd96..3a5cc89676 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py +++ b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py @@ -661,6 +661,7 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config): "gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE, "gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE, "qwen3-0.6b": QWEN3_HF_WEIGHTS_TO_SHAPE, + "qwen3-1.7b": QWEN3_HF_WEIGHTS_TO_SHAPE, "qwen3-4b": QWEN3_HF_WEIGHTS_TO_SHAPE, "qwen3-4b-thinking-2507": QWEN3_HF_WEIGHTS_TO_SHAPE, "qwen3-8b": QWEN3_HF_WEIGHTS_TO_SHAPE, diff --git a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py index a38e77c8f4..e4055d28d1 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py +++ b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py @@ -1578,6 +1578,7 @@ def scale_query_layer(input_tensor, target_shape): "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, @@ -1606,6 +1607,7 @@ def scale_query_layer(input_tensor, target_shape): "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, diff --git a/src/MaxText/utils/ckpt_conversion/utils/utils.py b/src/MaxText/utils/ckpt_conversion/utils/utils.py index 42eb439539..a96ad02330 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/utils.py +++ b/src/MaxText/utils/ckpt_conversion/utils/utils.py @@ -65,6 +65,7 @@ "gemma3-12b": "google/gemma-3-12b-it", "gemma3-27b": "google/gemma-3-27b-it", "qwen3-0.6b": "Qwen/Qwen3-0.6B", + "qwen3-1.7b": "Qwen/Qwen3-1.7B", "qwen3-4b": "Qwen/Qwen3-4B", "qwen3-4b-thinking-2507": "Qwen/Qwen3-4B-Thinking-2507", "qwen3-8b": "Qwen/Qwen3-8B", diff --git a/tests/unit/configs_test.py b/tests/unit/configs_test.py index 44dda1df3a..14f97e894e 100644 --- a/tests/unit/configs_test.py +++ b/tests/unit/configs_test.py @@ -230,6 +230,7 @@ def test_mistral_configs(config_file): QWEN_CONFIGS = [ os.path.join(CONFIGS_DIR, "models", "qwen3-0.6b.yml"), + os.path.join(CONFIGS_DIR, "models", "qwen3-1.7b.yml"), os.path.join(CONFIGS_DIR, "models", "qwen3-4b.yml"), os.path.join(CONFIGS_DIR, "models", "qwen3-4b-thinking-2507.yml"), os.path.join(CONFIGS_DIR, "models", "qwen3-8b.yml"), diff --git a/tests/unit/muon_test.py b/tests/unit/muon_test.py index 9fd847d04e..ab1f2344fd 100644 --- a/tests/unit/muon_test.py +++ b/tests/unit/muon_test.py @@ -185,7 +185,7 @@ # qwen3, specific: logits_via_embedding=True -# applicable: qwen3-0.6b, qwen3-4b, but not: qwen3-8b, qwen3-14b (logits_via_embedding=False) +# applicable: qwen3-0.6b, qwen3-1.7b, qwen3-4b, but not: qwen3-8b, qwen3-14b (logits_via_embedding=False) QWEN3_DIMENSION_NUMBER = { "params": { "decoder": { diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py index c096c98136..a3440facfe 100644 --- a/tests/utils/sharding_dump.py +++ b/tests/utils/sharding_dump.py @@ -61,6 +61,7 @@ # "gemma3-12b", # "gemma3-27b", # "qwen3-0.6b", + # "qwen3-1.7b", # "qwen3-4b", # "qwen3-8b", # "gpt3-175b", From 138dc1cf2575c4d4c0191383a061fb02c5a584f7 Mon Sep 17 00:00:00 2001 From: Richard <99752583+rxu183@users.noreply.github.com> Date: Wed, 18 Feb 2026 23:05:48 -0600 Subject: [PATCH 2/2] Fix mdformatting --- .../convert_checkpoint.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 1d677eda38..f6ff28b44e 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -6,16 +6,16 @@ This guide provides instructions for using the [scripts](https://github.com/AI-H The following models are supported: -| Model Family | Sizes | HF $\\to$ Orbax (scan) | HF $\\to$ Orbax (unscan) | Orbax (scan) $\\to$ HF | Orbax (unscan) $\\to$ HF | -| :---------------------- | :--------------------- | :--------------------: | :----------------------: | :--------------------: | :----------------------: | -| **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ | -| **Gemma3** (Multimodal) | 4B, 12B, 27B | - | √ | - | √ | -| **Llama3.1** | 8B, 70B, 450B | √ | √ | √ | √ | +| Model Family | Sizes | HF $\\to$ Orbax (scan) | HF $\\to$ Orbax (unscan) | Orbax (scan) $\\to$ HF | Orbax (unscan) $\\to$ HF | +| :---------------------- | :--------------------------- | :--------------------: | :----------------------: | :--------------------: | :----------------------: | +| **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ | +| **Gemma3** (Multimodal) | 4B, 12B, 27B | - | √ | - | √ | +| **Llama3.1** | 8B, 70B, 450B | √ | √ | √ | √ | | **Qwen3** | 0.6B, 1.7B, 4B, 8B, 14B, 32B | √ | √ | √ | √ | -| **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ | -| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ | -| **GPT-OSS** | 20B, 120B | √ | √ | √ | √ | -| **DeepSeek3** | 671B | - | - | √ | - | +| **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ | +| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ | +| **GPT-OSS** | 20B, 120B | √ | √ | √ | √ | +| **DeepSeek3** | 671B | - | - | √ | - | ## Prerequisites