1 change: 1 addition & 0 deletions src/MaxText/common_types.py
@@ -92,6 +92,7 @@ class DecoderBlockType(enum.Enum):
GEMMA = "gemma"
GEMMA2 = "gemma2"
GEMMA3 = "gemma3"
QWEN2 = "qwen2"
QWEN3 = "qwen3"
QWEN3_MOE = "qwen3_moe"
QWEN3_NEXT = "qwen3_next"
38 changes: 38 additions & 0 deletions src/MaxText/configs/models/qwen2.5-14b.yml
@@ -0,0 +1,38 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen2.5-14b

base_emb_dim: 5120
base_num_query_heads: 40
base_num_kv_heads: 8
base_mlp_dim: 13824
base_num_decoder_layers: 48
head_dim: 128
mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU
vocab_size: 152064

decoder_block: "qwen2"

normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000

use_qk_norm: False
attention_bias: True

logits_via_embedding: False
normalize_embedding_logits: False

tokenizer_type: "huggingface"

33 changes: 33 additions & 0 deletions src/MaxText/configs/models/qwen2.5-7b.yml
@@ -0,0 +1,33 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen2.5-7b

base_emb_dim: 3584
base_num_query_heads: 28
base_num_kv_heads: 4
base_mlp_dim: 18944
base_num_decoder_layers: 28
head_dim: 128
mlp_activations: ["silu", "linear"]
vocab_size: 152064
decoder_block: "qwen2"
normalization_layer_epsilon: 1e-06
rope_max_timescale: 1000000.0
use_qk_norm: False
attention_bias: True
logits_via_embedding: False
normalize_embedding_logits: False
tokenizer_type: "huggingface"
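
As a quick cross-check of the two configs above, the attention dimensions are internally consistent; a standalone sketch in plain Python (the values are copied by hand from the YAML, nothing here is generated by MaxText):

# Sanity check of the Qwen2.5 config dimensions (values copied from the YAML files above).
configs = {
    "qwen2.5-7b": dict(emb=3584, q_heads=28, kv_heads=4, head_dim=128),
    "qwen2.5-14b": dict(emb=5120, q_heads=40, kv_heads=8, head_dim=128),
}
for name, c in configs.items():
    assert c["emb"] == c["q_heads"] * c["head_dim"]   # embedding dim matches q_heads * head_dim
    assert c["q_heads"] % c["kv_heads"] == 0          # GQA: query heads divide evenly over KV heads
    print(name, "query heads per KV head:", c["q_heads"] // c["kv_heads"])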

2 changes: 2 additions & 0 deletions src/MaxText/configs/types.py
@@ -220,6 +220,8 @@ class ProfilerType(str, Enum):
"gemma3-4b",
"gemma3-12b",
"gemma3-27b",
"qwen2.5-7b",
"qwen2.5-14b",
"qwen3-0.6b",
"qwen3-4b",
"qwen3-4b-thinking-2507",
3 changes: 3 additions & 0 deletions src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -21,6 +21,7 @@
from MaxText.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.qwen2 import QWEN2_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING


@@ -30,6 +31,8 @@ class StandaloneVllmWeightMapping:
def __getattr__(self, name):
if name.startswith("llama3.1"):
return LLAMA3_VLLM_MAPPING
elif name.startswith("qwen2"):
return QWEN2_VLLM_MAPPING
elif name.startswith("qwen3"):
return QWEN3_VLLM_MAPPING
elif name.startswith("deepseek3"):
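
The `__getattr__` dispatch above matches on name prefixes, so any model id beginning with "qwen2" (including the new qwen2.5 variants) resolves to QWEN2_VLLM_MAPPING. A minimal sketch of that lookup, assuming the class is importable from the package as shown in this diff (the real Tunix call site may differ):

# Illustrative prefix dispatch, mirroring StandaloneVllmWeightMapping.__getattr__.
from MaxText.integration.tunix.weight_mapping import StandaloneVllmWeightMapping

mapping = getattr(StandaloneVllmWeightMapping(), "qwen2.5-7b")  # startswith("qwen2") -> QWEN2_VLLM_MAPPING
print(mapping.to_hf_mapping()["base.token_embedder.embedding"])  # ("model.embed.embedding", ("model", None))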
136 changes: 136 additions & 0 deletions src/MaxText/integration/tunix/weight_mapping/qwen2.py
@@ -0,0 +1,136 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the weight mapping from MaxText's Qwen2 model to a vLLM-compatible format.

This module provides the `QWEN2_VLLM_MAPPING` dataclass, which contains all the
necessary configurations to convert MaxText's Qwen2 model weights into a
format that can be loaded by vLLM. This includes:
- A direct mapping of parameter names.
- Sharding specifications for distributed environments.
"""

from dataclasses import dataclass


@dataclass
class QWEN2_VLLM_MAPPING:
"""Mapping MaxText Qwen2 weights to vLLM's Qwen2 weights."""

@staticmethod
def to_hf_hook_fns():
"""Returns a dictionary of hook functions to be applied to MaxText weights.

Returns:
An empty dictionary, as no hook functions are needed for this mapping.
"""

return {}

@staticmethod
def to_hf_transpose_keys():
"""Returns a list of keys for weights that need to be transposed.

Returns:
An empty dictionary, as no keys require transposition for this mapping.
"""
return {}

@staticmethod
def lora_to_hf_mappings():
"""Provides the mapping for LoRA (Low-Rank Adaptation) weights.

Returns:
None, as LoRA mappings are not defined for this model.
"""
return None

@staticmethod
def to_hf_mapping():
"""Mapping from MaxText model to HuggingFace vLLM model.

Currently, the param mapping conforms to the Tunix API, which combines the
param name and sharding spec in one dictionary; this may change in the
future so that the two can be decoupled.
"""
return {
# Token embeddings - shard vocab dimension
"base.token_embedder.embedding": (
"model.embed.embedding",
("model", None),
),
# Final layer norm - no sharding needed
"base.decoder.decoder_norm.scale": (
"model.norm.scale",
(None,),
),
# LM head (logits projection) - shard vocab dimension
"base.decoder.logits_dense.kernel": (
"model.lm_head",
(None, "model"),
),
# Layer-specific mappings (scanned -> unscanned)
# MLP components - shard hidden dimensions
"base.decoder.layers.mlp.wi_0.kernel": (
"model.layers.*.mlp.gate_proj.kernel",
(None, "layer", "model"),
),
"base.decoder.layers.mlp.wi_1.kernel": (
"model.layers.*.mlp.up_proj.kernel",
(None, "layer", "model"),
),
"base.decoder.layers.mlp.wo.kernel": (
"model.layers.*.mlp.down_proj.kernel",
("model", "layer", None),
),
# Layer norms - no sharding needed
"base.decoder.layers.pre_self_attention_layer_norm.scale": (
"model.layers.*.input_layernorm.scale",
(None, "layer"),
),
"base.decoder.layers.post_self_attention_layer_norm.scale": (
"model.layers.*.post_attention_layernorm.scale",
(None, "layer"),
),
# Attention components - shard head dimensions
"base.decoder.layers.self_attention.query.kernel": (
"model.layers.*.self_attn.q_proj.kernel",
(None, "layer", "model", None),
),
"base.decoder.layers.self_attention.key.kernel": (
"model.layers.*.self_attn.k_proj.kernel",
(None, "layer", "model", None),
),
"base.decoder.layers.self_attention.value.kernel": (
"model.layers.*.self_attn.v_proj.kernel",
(None, "layer", "model", None),
),
"base.decoder.layers.self_attention.out.kernel": (
"model.layers.*.self_attn.o_proj.kernel",
("model", "layer", None, None),
),
# Attention biases
"base.decoder.layers.self_attention.query.bias": (
"model.layers.*.self_attn.q_proj.bias",
(None, "layer", "model", None),
),
"base.decoder.layers.self_attention.key.bias": (
"model.layers.*.self_attn.k_proj.bias",
(None, "layer", "model", None),
),
"base.decoder.layers.self_attention.value.bias": (
"model.layers.*.self_attn.v_proj.bias",
(None, "layer", "model", None),
),
}
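
Each entry above pairs a MaxText (scanned) parameter path with a vLLM-side name pattern and a per-axis sharding tuple; the "*" stands for the layer index, and the "layer" axis corresponds to the scan dimension. A rough sketch of how one entry could be expanded for an unscanned checkpoint (illustrative only, not the actual Tunix conversion logic):

# Illustrative expansion of one mapping entry; not the real Tunix logic.
from MaxText.integration.tunix.weight_mapping.qwen2 import QWEN2_VLLM_MAPPING

hf_pattern, sharding = QWEN2_VLLM_MAPPING.to_hf_mapping()["base.decoder.layers.mlp.wi_0.kernel"]
num_layers = 28  # qwen2.5-7b
hf_names = [hf_pattern.replace("*", str(i)) for i in range(num_layers)]
print(hf_names[0])   # model.layers.0.mlp.gate_proj.kernel
print(sharding)      # (None, "layer", "model"): axis 1 is the scan axis, axis 2 is model-parallel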
3 changes: 2 additions & 1 deletion src/MaxText/layers/attentions.py
@@ -425,6 +425,7 @@ def __init__(
self.mrope_section = mrope_section
self.rngs = rngs

self.is_qwen2 = self.config.decoder_block == DecoderBlockType.QWEN2
self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT

# Module attribute names must match names previously passed to Linen for checkpointing
@@ -698,7 +699,7 @@ def init_out_w(self, output_dim: int) -> nnx.Module:
quant=self.quant,
shard_mode=self.config.shard_mode,
matmul_precision=self.config.matmul_precision,
use_bias=self.use_bias_in_projections,
use_bias=False if self.is_qwen2 else self.use_bias_in_projections,
rngs=self.rngs,
)

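
With attention_bias: True in the qwen2.5 configs, use_bias_in_projections enables bias on the query/key/value projections, while the is_qwen2 check above keeps the output projection bias-free; this matches the upstream Qwen2 attention layout. My reading of the resulting layout, summarized as a small sketch (not generated output):

# Expected bias layout for a Qwen2.5 attention block after this change (reading of the diff).
qwen2_attention_bias = {
    "q_proj": True,   # use_bias_in_projections == config.attention_bias == True
    "k_proj": True,
    "v_proj": True,
    "o_proj": False,  # init_out_w forces use_bias=False when is_qwen2
}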
3 changes: 3 additions & 0 deletions src/MaxText/layers/decoders.py
@@ -420,6 +420,8 @@ def get_decoder_layers(self):
return [gpt3.Gpt3DecoderLayerToLinen]
case DecoderBlockType.GPT_OSS:
return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen]
case DecoderBlockType.QWEN2:
return [qwen3.Qwen3DecoderLayerToLinen]
case DecoderBlockType.QWEN3:
return [qwen3.Qwen3DecoderLayerToLinen]
case DecoderBlockType.QWEN3_MOE:
@@ -478,6 +480,7 @@ def get_norm_layer(self, num_features: int):
DecoderBlockType.GEMMA,
DecoderBlockType.GEMMA2,
DecoderBlockType.GEMMA3,
DecoderBlockType.QWEN2,
DecoderBlockType.QWEN3,
DecoderBlockType.QWEN3_MOE,
DecoderBlockType.GPT_OSS,
1 change: 1 addition & 0 deletions src/MaxText/layers/qwen3.py
@@ -966,6 +966,7 @@ def __init__(
use_ragged_attention=config.use_ragged_attention,
ragged_block_size=config.ragged_block_size,
use_qk_norm=config.use_qk_norm,
use_bias_in_projections=config.attention_bias,
query_pre_attn_scalar=query_pre_attn_scalar,
model_mode=model_mode,
use_mrope=config.use_mrope,
2 changes: 2 additions & 0 deletions src/MaxText/pyconfig_deprecated.py
@@ -460,6 +460,8 @@ def validate_model_name(s: str) -> bool:
"gemma3-4b",
"gemma3-12b",
"gemma3-27b",
"qwen2.5-7b",
"qwen2.5-14b",
"qwen3-0.6b",
"qwen3-4b",
"qwen3-4b-thinking-2507",
37 changes: 37 additions & 0 deletions src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
@@ -210,6 +210,41 @@
query_pre_attn_scalar=144,
)

qwen25_7b_config = transformers.Qwen2Config(
vocab_size=152064,
hidden_size=3584,
intermediate_size=18944,
num_hidden_layers=28,
num_attention_heads=28,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-06,
use_cache=True,
rope_theta=1000000.0,
tie_word_embeddings=False,
torch_dtype="bfloat16",
attention_bias=True,
)

qwen25_14b_config = transformers.Qwen2Config(
vocab_size=152064,
hidden_size=5120,
intermediate_size=13824,
num_hidden_layers=48,
num_attention_heads=40,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
rms_norm_eps=1e-06,
rope_theta=1000000.0,
tie_word_embeddings=False,
torch_dtype="bfloat16",
attention_bias=True,
)


qwen3_0_6b_config = transformers.Qwen3Config(
vocab_size=151936,
hidden_size=1024,
@@ -772,6 +807,8 @@
"gemma3-4b": gemma3_4b_config,
"gemma3-12b": gemma3_12b_config,
"gemma3-27b": gemma3_27b_config,
"qwen2.5-7b": qwen25_7b_config,
"qwen2.5-14b": qwen25_14b_config,
"qwen3-0.6b": qwen3_0_6b_config,
"qwen3-4b": qwen3_4b_config,
"qwen3-4b-thinking-2507": qwen3_4b_config,
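
The two Qwen2Config objects mirror the YAML files added earlier in this PR; a small consistency check between the two sources (a sketch: the HF field names are from transformers.Qwen2Config, and the MaxText values are copied by hand from qwen2.5-7b.yml):

# Cross-check the HF config against the MaxText qwen2.5-7b YAML (values copied by hand).
from transformers import Qwen2Config

hf = Qwen2Config(vocab_size=152064, hidden_size=3584, intermediate_size=18944,
                 num_hidden_layers=28, num_attention_heads=28, num_key_value_heads=4,
                 rms_norm_eps=1e-06, rope_theta=1000000.0)
maxtext = dict(base_emb_dim=3584, base_mlp_dim=18944, base_num_decoder_layers=28,
               base_num_query_heads=28, base_num_kv_heads=4, vocab_size=152064)

assert hf.hidden_size == maxtext["base_emb_dim"]
assert hf.intermediate_size == maxtext["base_mlp_dim"]
assert hf.num_hidden_layers == maxtext["base_num_decoder_layers"]
assert hf.num_key_value_heads == maxtext["base_num_kv_heads"]
assert hf.vocab_size == maxtext["vocab_size"]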
12 changes: 12 additions & 0 deletions src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
@@ -484,6 +484,15 @@ def QWEN3_HF_WEIGHTS_TO_SHAPE(config):
f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
}

if attention_bias:
layer_mapping.update(
{
f"{layer_prefix}.self_attn.q_proj.bias": [num_attention_heads * head_dim],
f"{layer_prefix}.self_attn.k_proj.bias": [num_key_value_heads * head_dim],
f"{layer_prefix}.self_attn.v_proj.bias": [num_key_value_heads * head_dim],
}
)

if num_experts > 1:
# MoE MLP layers
moe_ffn_intermediate_size = config.get("moe_intermediate_size")
@@ -660,6 +669,8 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config):
"gemma3-4b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
"gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
"gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
"qwen2.5-7b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen2.5-14b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-0.6b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-4b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-4b-thinking-2507": QWEN3_HF_WEIGHTS_TO_SHAPE,
@@ -678,3 +689,4 @@
"mixtral-8x7b": MIXTRAL_HF_WEIGHTS_TO_SHAPE,
"mixtral-8x22b": MIXTRAL_HF_WEIGHTS_TO_SHAPE,
}
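
Reusing QWEN3_HF_WEIGHTS_TO_SHAPE for the qwen2.5 entries relies on the attention_bias branch added above; for the head counts in this PR the bias shapes work out as follows (a quick worked check, values copied from the configs earlier in the diff):

# Worked check of the q/k/v bias shapes registered above, using the qwen2.5 head counts from this PR.
for name, (q_heads, kv_heads, head_dim) in {
    "qwen2.5-7b": (28, 4, 128),
    "qwen2.5-14b": (40, 8, 128),
}.items():
    print(name, "q_proj.bias:", [q_heads * head_dim], "k/v_proj.bias:", [kv_heads * head_dim])
# qwen2.5-7b:  q_proj.bias [3584], k/v_proj.bias [512]
# qwen2.5-14b: q_proj.bias [5120], k/v_proj.bias [1024]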
