diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 97a15616bf..f6ff28b44e 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -6,16 +6,16 @@ This guide provides instructions for using the [scripts](https://github.com/AI-H The following models are supported: -| Model Family | Sizes | HF $\\to$ Orbax (scan) | HF $\\to$ Orbax (unscan) | Orbax (scan) $\\to$ HF | Orbax (unscan) $\\to$ HF | -| :---------------------- | :--------------------- | :--------------------: | :----------------------: | :--------------------: | :----------------------: | -| **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ | -| **Gemma3** (Multimodal) | 4B, 12B, 27B | - | √ | - | √ | -| **Llama3.1** | 8B, 70B, 450B | √ | √ | √ | √ | -| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B | √ | √ | √ | √ | -| **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ | -| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ | -| **GPT-OSS** | 20B, 120B | √ | √ | √ | √ | -| **DeepSeek3** | 671B | - | - | √ | - | +| Model Family | Sizes | HF $\\to$ Orbax (scan) | HF $\\to$ Orbax (unscan) | Orbax (scan) $\\to$ HF | Orbax (unscan) $\\to$ HF | +| :---------------------- | :--------------------------- | :--------------------: | :----------------------: | :--------------------: | :----------------------: | +| **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ | +| **Gemma3** (Multimodal) | 4B, 12B, 27B | - | √ | - | √ | +| **Llama3.1** | 8B, 70B, 405B | √ | √ | √ | √ | +| **Qwen3** | 0.6B, 1.7B, 4B, 8B, 14B, 32B | √ | √ | √ | √ | +| **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ | +| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ | +| **GPT-OSS** | 20B, 120B | √ | √ | √ | √ | +| **DeepSeek3** | 671B | - | - | √ | - | ## Prerequisites diff --git a/src/MaxText/pyconfig_deprecated.py b/src/MaxText/pyconfig_deprecated.py index 582d7a122f..7d8bbebcef 100644 --- a/src/MaxText/pyconfig_deprecated.py +++
b/src/MaxText/pyconfig_deprecated.py @@ -461,6 +461,7 @@ def validate_model_name(s: str) -> bool: "gemma3-12b", "gemma3-27b", "qwen3-0.6b", + "qwen3-1.7b", "qwen3-4b", "qwen3-4b-thinking-2507", "qwen3-8b", diff --git a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py index d91b7987ca..1eab21606d 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py @@ -226,6 +226,22 @@ torch_dtype="bfloat16", ) +qwen3_1_7b_config = transformers.Qwen3Config( + vocab_size=151936, + hidden_size=2048, + intermediate_size=6144, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=8, + head_dim=128, + hidden_act="silu", + max_position_embeddings=40960, + rms_norm_eps=1.0e-6, + rope_theta=1000000.0, + tie_word_embeddings=True, + torch_dtype="bfloat16", +) + qwen3_4b_config = transformers.Qwen3Config( vocab_size=151936, hidden_size=2560, @@ -773,6 +789,7 @@ "gemma3-12b": gemma3_12b_config, "gemma3-27b": gemma3_27b_config, "qwen3-0.6b": qwen3_0_6b_config, + "qwen3-1.7b": qwen3_1_7b_config, "qwen3-4b": qwen3_4b_config, "qwen3-4b-thinking-2507": qwen3_4b_config, "qwen3-8b": qwen3_8b_config, diff --git a/src/maxtext/checkpoint_conversion/utils/hf_shape.py b/src/maxtext/checkpoint_conversion/utils/hf_shape.py index 081017dd96..3a5cc89676 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_shape.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_shape.py @@ -661,6 +661,7 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config): "gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE, "gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE, "qwen3-0.6b": QWEN3_HF_WEIGHTS_TO_SHAPE, + "qwen3-1.7b": QWEN3_HF_WEIGHTS_TO_SHAPE, "qwen3-4b": QWEN3_HF_WEIGHTS_TO_SHAPE, "qwen3-4b-thinking-2507": QWEN3_HF_WEIGHTS_TO_SHAPE, "qwen3-8b": QWEN3_HF_WEIGHTS_TO_SHAPE, diff --git a/src/maxtext/checkpoint_conversion/utils/param_mapping.py 
b/src/maxtext/checkpoint_conversion/utils/param_mapping.py index d4f7317969..b76bfb4403 100644 --- a/src/maxtext/checkpoint_conversion/utils/param_mapping.py +++ b/src/maxtext/checkpoint_conversion/utils/param_mapping.py @@ -2083,6 +2083,7 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, @@ -2114,6 +2115,7 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index 5a0ecfe940..374335a9e9 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -65,6 +65,7 @@ "gemma3-12b": "google/gemma-3-12b-it", "gemma3-27b": "google/gemma-3-27b-it", "qwen3-0.6b": "Qwen/Qwen3-0.6B", + "qwen3-1.7b": "Qwen/Qwen3-1.7B", "qwen3-4b": "Qwen/Qwen3-4B", "qwen3-4b-thinking-2507": "Qwen/Qwen3-4B-Thinking-2507", "qwen3-8b": "Qwen/Qwen3-8B", diff --git a/src/maxtext/configs/models/qwen3-1.7b.yml b/src/maxtext/configs/models/qwen3-1.7b.yml new file mode 100644 index 0000000000..79b7abe8e7 --- /dev/null +++ b/src/maxtext/configs/models/qwen3-1.7b.yml @@ -0,0 +1,37 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in 
compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for qwen3-1.7b + +base_emb_dim: 2048 +base_num_query_heads: 16 +base_num_kv_heads: 8 +base_mlp_dim: 6144 +base_num_decoder_layers: 28 +head_dim: 128 +mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU +vocab_size: 151936 + +decoder_block: "qwen3" + +normalization_layer_epsilon: 1.0e-6 +rope_max_timescale: 1000000 + +use_qk_norm: True + +logits_via_embedding: True # from "tie_word_embeddings": true +normalize_embedding_logits: False +enable_dropout: False # deterministic for testing + +tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index cde89b0c7e..572b4ed932 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -229,6 +229,7 @@ class ProfilerType(str, Enum): "gemma3-12b", "gemma3-27b", "qwen3-0.6b", + "qwen3-1.7b", "qwen3-4b", "qwen3-4b-thinking-2507", "qwen3-8b", diff --git a/tests/unit/configs_test.py b/tests/unit/configs_test.py index 09ceb03f44..fb410be224 100644 --- a/tests/unit/configs_test.py +++ b/tests/unit/configs_test.py @@ -230,6 +230,7 @@ def test_mistral_configs(config_file): QWEN_CONFIGS = [ os.path.join(CONFIGS_DIR, "models", "qwen3-0.6b.yml"), + os.path.join(CONFIGS_DIR, "models", "qwen3-1.7b.yml"), os.path.join(CONFIGS_DIR, "models", "qwen3-4b.yml"), os.path.join(CONFIGS_DIR, "models", "qwen3-4b-thinking-2507.yml"), os.path.join(CONFIGS_DIR, "models", "qwen3-8b.yml"), diff --git a/tests/unit/muon_test.py b/tests/unit/muon_test.py index 9fd847d04e..ab1f2344fd 100644 --- 
a/tests/unit/muon_test.py +++ b/tests/unit/muon_test.py @@ -185,7 +185,7 @@ # qwen3, specific: logits_via_embedding=True -# applicable: qwen3-0.6b, qwen3-4b, but not: qwen3-8b, qwen3-14b (logits_via_embedding=False) +# applicable: qwen3-0.6b, qwen3-1.7b, qwen3-4b, but not: qwen3-8b, qwen3-14b (logits_via_embedding=False) QWEN3_DIMENSION_NUMBER = { "params": { "decoder": { diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py index ec7b98b752..cf051a6fd8 100644 --- a/tests/utils/sharding_dump.py +++ b/tests/utils/sharding_dump.py @@ -66,6 +66,7 @@ # "gemma3-12b", # "gemma3-27b", "qwen3-0.6b", + # "qwen3-1.7b", # "qwen3-4b", # "qwen3-4b-thinking-2507", # "qwen3-8b",