From 444f17986c087786e227ab26054caf76198d31a6 Mon Sep 17 00:00:00 2001
From: Jimmy Tsai
Date: Fri, 26 Dec 2025 08:48:57 +0000
Subject: [PATCH] Add qwen2 implementation

---
 src/MaxText/common_types.py                   |   1 +
 src/MaxText/configs/models/qwen2.5-14b.yml    |  38 +++++
 src/MaxText/configs/models/qwen2.5-7b.yml     |  33 +++++
 src/MaxText/configs/types.py                  |   2 +
 .../tunix/weight_mapping/__init__.py          |   3 +
 .../integration/tunix/weight_mapping/qwen2.py | 136 ++++++++++++++++++
 src/MaxText/layers/attentions.py              |   3 +-
 src/MaxText/layers/decoders.py                |   3 +
 src/MaxText/layers/qwen3.py                   |   1 +
 src/MaxText/pyconfig_deprecated.py            |   2 +
 .../ckpt_conversion/utils/hf_model_configs.py |  37 +++++
 .../utils/ckpt_conversion/utils/hf_shape.py   |  12 ++
 .../ckpt_conversion/utils/param_mapping.py    |  42 ++++++
 .../utils/ckpt_conversion/utils/utils.py      |   2 +
 14 files changed, 314 insertions(+), 1 deletion(-)
 create mode 100644 src/MaxText/configs/models/qwen2.5-14b.yml
 create mode 100644 src/MaxText/configs/models/qwen2.5-7b.yml
 create mode 100644 src/MaxText/integration/tunix/weight_mapping/qwen2.py

diff --git a/src/MaxText/common_types.py b/src/MaxText/common_types.py
index f36b991cef..80f1d907da 100644
--- a/src/MaxText/common_types.py
+++ b/src/MaxText/common_types.py
@@ -92,6 +92,7 @@ class DecoderBlockType(enum.Enum):
   GEMMA = "gemma"
   GEMMA2 = "gemma2"
   GEMMA3 = "gemma3"
+  QWEN2 = "qwen2"
   QWEN3 = "qwen3"
   QWEN3_MOE = "qwen3_moe"
   QWEN3_NEXT = "qwen3_next"
diff --git a/src/MaxText/configs/models/qwen2.5-14b.yml b/src/MaxText/configs/models/qwen2.5-14b.yml
new file mode 100644
index 0000000000..92392d1ad7
--- /dev/null
+++ b/src/MaxText/configs/models/qwen2.5-14b.yml
@@ -0,0 +1,38 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for qwen2.5-14b
+
+base_emb_dim: 5120
+base_num_query_heads: 40
+base_num_kv_heads: 8
+base_mlp_dim: 13824
+base_num_decoder_layers: 48
+head_dim: 128
+mlp_activations: ["silu", "linear"]  # "hidden_act": "silu" implies SwiGLU
+vocab_size: 152064
+
+decoder_block: "qwen2"
+
+normalization_layer_epsilon: 1.0e-6
+rope_max_timescale: 1000000
+
+use_qk_norm: False
+attention_bias: True
+
+logits_via_embedding: False
+normalize_embedding_logits: False
+
+tokenizer_type: "huggingface"
+
diff --git a/src/MaxText/configs/models/qwen2.5-7b.yml b/src/MaxText/configs/models/qwen2.5-7b.yml
new file mode 100644
index 0000000000..0876baf721
--- /dev/null
+++ b/src/MaxText/configs/models/qwen2.5-7b.yml
@@ -0,0 +1,33 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for qwen2.5-7b
+
+base_emb_dim: 3584
+base_num_query_heads: 28
+base_num_kv_heads: 4
+base_mlp_dim: 18944
+base_num_decoder_layers: 28
+head_dim: 128
+mlp_activations: ["silu", "linear"]  # "hidden_act": "silu" implies SwiGLU
+vocab_size: 152064
+decoder_block: "qwen2"
+normalization_layer_epsilon: 1.0e-6
+rope_max_timescale: 1000000
+use_qk_norm: False
+attention_bias: True
+logits_via_embedding: False
+normalize_embedding_logits: False
+tokenizer_type: "huggingface"
+
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index f55b101581..b4dd5a7b64 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -220,6 +220,8 @@ class ProfilerType(str, Enum):
     "gemma3-4b",
     "gemma3-12b",
     "gemma3-27b",
+    "qwen2.5-7b",
+    "qwen2.5-14b",
     "qwen3-0.6b",
     "qwen3-4b",
     "qwen3-4b-thinking-2507",
diff --git a/src/MaxText/integration/tunix/weight_mapping/__init__.py b/src/MaxText/integration/tunix/weight_mapping/__init__.py
index 7f7a0dc534..2c218acc56 100644
--- a/src/MaxText/integration/tunix/weight_mapping/__init__.py
+++ b/src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -21,6 +21,7 @@
 from MaxText.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
+from MaxText.integration.tunix.weight_mapping.qwen2 import QWEN2_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
 
 
@@ -30,6 +31,8 @@ class StandaloneVllmWeightMapping:
   def __getattr__(self, name):
     if name.startswith("llama3.1"):
       return LLAMA3_VLLM_MAPPING
+    elif name.startswith("qwen2"):
+      return QWEN2_VLLM_MAPPING
     elif name.startswith("qwen3"):
       return QWEN3_VLLM_MAPPING
     elif name.startswith("deepseek3"):
diff --git a/src/MaxText/integration/tunix/weight_mapping/qwen2.py b/src/MaxText/integration/tunix/weight_mapping/qwen2.py
new file mode 100644
index 0000000000..583b55b6f5
--- /dev/null
+++ b/src/MaxText/integration/tunix/weight_mapping/qwen2.py
@@ -0,0 +1,136 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the weight mapping from MaxText's Qwen2 model to a vLLM-compatible format.
+
+This module provides the `QWEN2_VLLM_MAPPING` dataclass, which contains all the
+necessary configurations to convert MaxText's Qwen2 model weights into a
+format that can be loaded by vLLM. This includes:
+- A direct mapping of parameter names.
+- Sharding specifications for distributed environments.
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class QWEN2_VLLM_MAPPING:
+  """Mapping MaxText Qwen2 weights to vLLM's Qwen2 weights."""
+
+  @staticmethod
+  def to_hf_hook_fns():
+    """Returns a dictionary of hook functions to be applied to MaxText weights.
+
+    Returns:
+      An empty dictionary, as no hook functions are needed for this mapping.
+ """ + + return {} + + @staticmethod + def to_hf_transpose_keys(): + """Returns a list of keys for weights that need to be transposed. + + Returns: + An empty dictionary, as no keys require transposition for this mapping. + """ + return {} + + @staticmethod + def lora_to_hf_mappings(): + """Provides the mapping for LoRA (Low-Rank Adaptation) weights. + + Returns: + None, as LoRA mappings are not defined for this model. + """ + return None + + @staticmethod + def to_hf_mapping(): + """Mapping from MaxText model to HuggingFace vLLM model. + + Currently, the param mapping conforms to the Tunix API, which combines the + param name & sharding in one dictionary. + This is subject to change in the future where we can decouple the two. + """ + return { + # Token embeddings - shard vocab dimension + "base.token_embedder.embedding": ( + "model.embed.embedding", + ("model", None), + ), + # Final layer norm - no sharding needed + "base.decoder.decoder_norm.scale": ( + "model.norm.scale", + (None,), + ), + # LM head (logits projection) - shard vocab dimension + "base.decoder.logits_dense.kernel": ( + "model.lm_head", + (None, "model"), + ), + # Layer-specific mappings (scanned -> unscanned) + # MLP components - shard hidden dimensions + "base.decoder.layers.mlp.wi_0.kernel": ( + "model.layers.*.mlp.gate_proj.kernel", + (None, "layer", "model"), + ), + "base.decoder.layers.mlp.wi_1.kernel": ( + "model.layers.*.mlp.up_proj.kernel", + (None, "layer", "model"), + ), + "base.decoder.layers.mlp.wo.kernel": ( + "model.layers.*.mlp.down_proj.kernel", + ("model", "layer", None), + ), + # Layer norms - no sharding needed + "base.decoder.layers.pre_self_attention_layer_norm.scale": ( + "model.layers.*.input_layernorm.scale", + (None, "layer"), + ), + "base.decoder.layers.post_self_attention_layer_norm.scale": ( + "model.layers.*.post_attention_layernorm.scale", + (None, "layer"), + ), + # Attention components - shard head dimensions + "base.decoder.layers.self_attention.query.kernel": ( + "model.layers.*.self_attn.q_proj.kernel", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.key.kernel": ( + "model.layers.*.self_attn.k_proj.kernel", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.value.kernel": ( + "model.layers.*.self_attn.v_proj.kernel", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.out.kernel": ( + "model.layers.*.self_attn.o_proj.kernel", + ("model", "layer", None, None), + ), + # Attention biases + "base.decoder.layers.self_attention.query.bias": ( + "model.layers.*.self_attn.q_proj.bias", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.key.bias": ( + "model.layers.*.self_attn.k_proj.bias", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.value.bias": ( + "model.layers.*.self_attn.v_proj.bias", + (None, "layer", "model", None), + ), + } \ No newline at end of file diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py index 8f7c63fa41..d2c2034174 100644 --- a/src/MaxText/layers/attentions.py +++ b/src/MaxText/layers/attentions.py @@ -425,6 +425,7 @@ def __init__( self.mrope_section = mrope_section self.rngs = rngs + self.is_qwen2 = self.config.decoder_block == DecoderBlockType.QWEN2 self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT # Module attribute names must match names previously passed to Linen for checkpointing @@ -698,7 +699,7 @@ def init_out_w(self, output_dim: int) -> nnx.Module: quant=self.quant, 
         shard_mode=self.config.shard_mode,
         matmul_precision=self.config.matmul_precision,
-        use_bias=self.use_bias_in_projections,
+        use_bias=False if self.is_qwen2 else self.use_bias_in_projections,  # Qwen2: Q/K/V carry biases, o_proj does not
         rngs=self.rngs,
     )
 
diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py
index 86d3090c47..faf7568354 100644
--- a/src/MaxText/layers/decoders.py
+++ b/src/MaxText/layers/decoders.py
@@ -420,6 +420,8 @@ def get_decoder_layers(self):
         return [gpt3.Gpt3DecoderLayerToLinen]
       case DecoderBlockType.GPT_OSS:
         return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen]
+      case DecoderBlockType.QWEN2:
+        return [qwen3.Qwen3DecoderLayerToLinen]  # Qwen2 reuses the Qwen3 decoder layer (QK-norm off, attention biases on)
       case DecoderBlockType.QWEN3:
         return [qwen3.Qwen3DecoderLayerToLinen]
       case DecoderBlockType.QWEN3_MOE:
@@ -478,6 +480,7 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.GEMMA,
         DecoderBlockType.GEMMA2,
         DecoderBlockType.GEMMA3,
+        DecoderBlockType.QWEN2,
         DecoderBlockType.QWEN3,
         DecoderBlockType.QWEN3_MOE,
         DecoderBlockType.GPT_OSS,
diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py
index 1ebdb2ce42..c35c303a3e 100644
--- a/src/MaxText/layers/qwen3.py
+++ b/src/MaxText/layers/qwen3.py
@@ -966,6 +966,7 @@ def __init__(
         use_ragged_attention=config.use_ragged_attention,
         ragged_block_size=config.ragged_block_size,
         use_qk_norm=config.use_qk_norm,
+        use_bias_in_projections=config.attention_bias,
         query_pre_attn_scalar=query_pre_attn_scalar,
         model_mode=model_mode,
         use_mrope=config.use_mrope,
diff --git a/src/MaxText/pyconfig_deprecated.py b/src/MaxText/pyconfig_deprecated.py
index 582d7a122f..34b53d1579 100644
--- a/src/MaxText/pyconfig_deprecated.py
+++ b/src/MaxText/pyconfig_deprecated.py
@@ -460,6 +460,8 @@ def validate_model_name(s: str) -> bool:
     "gemma3-4b",
     "gemma3-12b",
     "gemma3-27b",
+    "qwen2.5-7b",
+    "qwen2.5-14b",
     "qwen3-0.6b",
     "qwen3-4b",
     "qwen3-4b-thinking-2507",
diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
index d91b7987ca..ddec2e336e 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
@@ -210,6 +210,41 @@
     query_pre_attn_scalar=144,
 )
 
+qwen25_7b_config = transformers.Qwen2Config(
+    vocab_size=152064,
+    hidden_size=3584,
+    intermediate_size=18944,
+    num_hidden_layers=28,
+    num_attention_heads=28,
+    num_key_value_heads=4,
+    hidden_act="silu",
+    max_position_embeddings=32768,
+    initializer_range=0.02,
+    rms_norm_eps=1e-06,
+    use_cache=True,
+    rope_theta=1000000.0,
+    tie_word_embeddings=False,
+    torch_dtype="bfloat16",
+    attention_bias=True,
+)
+
+qwen25_14b_config = transformers.Qwen2Config(
+    vocab_size=152064,
+    hidden_size=5120,
+    intermediate_size=13824,
+    num_hidden_layers=48,
+    num_attention_heads=40,
+    num_key_value_heads=8,
+    hidden_act="silu",
+    max_position_embeddings=32768,
+    rms_norm_eps=1e-06,
+    rope_theta=1000000.0,
+    tie_word_embeddings=False,
+    torch_dtype="bfloat16",
+    attention_bias=True,
+)
+
+
 qwen3_0_6b_config = transformers.Qwen3Config(
     vocab_size=151936,
     hidden_size=1024,
@@ -772,6 +807,8 @@
     "gemma3-4b": gemma3_4b_config,
     "gemma3-12b": gemma3_12b_config,
     "gemma3-27b": gemma3_27b_config,
+    "qwen2.5-7b": qwen25_7b_config,
+    "qwen2.5-14b": qwen25_14b_config,
     "qwen3-0.6b": qwen3_0_6b_config,
     "qwen3-4b": qwen3_4b_config,
     "qwen3-4b-thinking-2507": qwen3_4b_config,
diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
index 081017dd96..7867985bb1 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
@@ -484,6 +493,16 @@ def QWEN3_HF_WEIGHTS_TO_SHAPE(config):
         f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
     }
 
+    attention_bias = config.get("attention_bias", False)  # set True by the Qwen2.5 configs; absent for Qwen3
+    if attention_bias:
+      layer_mapping.update(
+          {
+              f"{layer_prefix}.self_attn.q_proj.bias": [num_attention_heads * head_dim],
+              f"{layer_prefix}.self_attn.k_proj.bias": [num_key_value_heads * head_dim],
+              f"{layer_prefix}.self_attn.v_proj.bias": [num_key_value_heads * head_dim],
+          }
+      )
+
     if num_experts > 1:
       # MoE MLP layers
       moe_ffn_intermediate_size = config.get("moe_intermediate_size")
@@ -660,6 +670,8 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config):
     "gemma3-4b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
     "gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
     "gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
+    "qwen2.5-7b": QWEN3_HF_WEIGHTS_TO_SHAPE,
+    "qwen2.5-14b": QWEN3_HF_WEIGHTS_TO_SHAPE,
     "qwen3-0.6b": QWEN3_HF_WEIGHTS_TO_SHAPE,
     "qwen3-4b": QWEN3_HF_WEIGHTS_TO_SHAPE,
    "qwen3-4b-thinking-2507": QWEN3_HF_WEIGHTS_TO_SHAPE,
diff --git a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
index a38e77c8f4..8cf81b0049 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
@@ -617,6 +617,15 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
     "params-decoder-layers-self_attention-value-kernel": [
         f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers)
     ],
+    "params-decoder-layers-self_attention-query-bias": [
+        f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers)
+    ],
+    "params-decoder-layers-self_attention-key-bias": [
+        f"model.layers.{i}.self_attn.k_proj.bias" for i in range(n_layers)
+    ],
+    "params-decoder-layers-self_attention-value-bias": [
+        f"model.layers.{i}.self_attn.v_proj.bias" for i in range(n_layers)
+    ],
     "params-decoder-layers-self_attention-out-kernel": [
         f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers)
     ],
@@ -674,6 +683,11 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
       f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
       f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
       f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
+
+      f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias",
+      f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias",
+      f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias",
+
       f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
       f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
       f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
@@ -751,6 +765,15 @@ def reshape_kernel(input_tensor, target_shape):
         return input_tensor.reshape(flipped_target_shape).T
       else:
         return input_tensor.T.reshape(target_shape)
+
+    def reshape_bias(input_tensor, target_shape):
+      """Reshapes biases between MaxText 2D (heads, head_dim) and HF 1D (hidden_dim)."""
+      if saving_to_hf:
+        # MaxText [heads, head_dim] -> HF [hidden_dim] (flatten)
+        return input_tensor.reshape(target_shape)
+      else:
+        # HF [hidden_dim] -> MaxText [heads, head_dim] (unflatten)
+        return input_tensor.reshape(target_shape)
 
     mapping = {
         "params-token_embedder-embedding": pad_embedding_layer,
@@ -766,6 +789,11 @@ def reshape_kernel(input_tensor, target_shape):
         "mlp-wi_1-kernel",
         "mlp-wo-kernel",
     ]
+    bias_hooks = [
+        "self_attention-query-bias",
+        "self_attention-key-bias",
+        "self_attention-value-bias",
+    ]
     moe_kernel_hooks = [
         "moe_block-gate-kernel",
         "moe_block-wi_0-kernel",
@@ -779,6 +807,8 @@ def reshape_kernel(input_tensor, target_shape):
     if scan_layers:
       for key in kernel_hooks:
         mapping[f"params-decoder-layers-{key}"] = reshape_kernel
+      for key in bias_hooks:
+        mapping[f"params-decoder-layers-{key}"] = reshape_bias
       if num_experts > 1:
         for key in moe_kernel_hooks:
           mapping[f"params-decoder-layers-{key}"] = reshape_kernel
@@ -786,6 +816,8 @@ def reshape_kernel(input_tensor, target_shape):
     for i in range(n_layers):
       for key in kernel_hooks:
         mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
+      for key in bias_hooks:
+        mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias
       if num_experts > 1:
         for key in moe_kernel_hooks:
           mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
@@ -1577,6 +1609,11 @@ def scale_query_layer(input_tensor, target_shape):
     "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen2.5-0.5b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen2.5-1.5b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen2.5-3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen2.5-7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen2.5-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
@@ -1605,6 +1642,11 @@ def scale_query_layer(input_tensor, target_shape):
     "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen2.5-0.5b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen2.5-1.5b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen2.5-3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen2.5-7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen2.5-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
diff --git a/src/MaxText/utils/ckpt_conversion/utils/utils.py b/src/MaxText/utils/ckpt_conversion/utils/utils.py
index 42eb439539..a4509ba799 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/utils.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/utils.py
@@ -64,6 +64,8 @@
     "gemma3-4b": "google/gemma-3-4b-it",  # hf multi-modal should also support the pure-text
     "gemma3-12b": "google/gemma-3-12b-it",
     "gemma3-27b": "google/gemma-3-27b-it",
+    "qwen2.5-7b": "Qwen/Qwen2.5-7B-Instruct",
+    "qwen2.5-14b": "Qwen/Qwen2.5-14B-Instruct",
     "qwen3-0.6b": "Qwen/Qwen3-0.6B",
     "qwen3-4b": "Qwen/Qwen3-4B",
     "qwen3-4b-thinking-2507": "Qwen/Qwen3-4B-Thinking-2507",
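
A note on the bias layout, since it is the one non-obvious convention in this patch: HF checkpoints store Qwen2's q/k/v projection biases as 1-D vectors of length heads * head_dim, while MaxText keeps them as 2-D [heads, head_dim] arrays, so reshape_bias above is a plain reshape in either direction. Below is a minimal, self-contained sketch of that round trip using the qwen2.5-7b head counts from the configs in this patch; the script name and the numpy stand-in tensors are illustrative only, not MaxText APIs.

    # qwen2_bias_roundtrip.py -- illustrative sketch, not part of the patch
    import numpy as np

    num_q_heads, num_kv_heads, head_dim = 28, 4, 128  # from qwen2.5-7b.yml

    # HF layout: q_proj.bias is 1-D [28 * 128] = [3584];
    # k_proj.bias / v_proj.bias are 1-D [4 * 128] = [512] under GQA.
    hf_q_bias = np.random.default_rng(0).normal(size=num_q_heads * head_dim).astype(np.float32)

    # HF -> MaxText: unflatten to [heads, head_dim] (reshape_bias, loading direction).
    mt_q_bias = hf_q_bias.reshape(num_q_heads, head_dim)

    # MaxText -> HF: flatten back to 1-D (reshape_bias, saving_to_hf direction).
    hf_round_trip = mt_q_bias.reshape(num_q_heads * head_dim)

    assert mt_q_bias.shape == (28, 128)
    assert np.array_equal(hf_q_bias, hf_round_trip)  # reshape is lossless
    print("qkv bias reshape round-trip OK")

The same arithmetic explains the shapes added to QWEN3_HF_WEIGHTS_TO_SHAPE above: the query bias is num_attention_heads * head_dim = 3584 wide, while the grouped key/value biases are num_key_value_heads * head_dim = 512.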