diff --git a/lightllm/common/basemodel/layer_weights/base_layer_weight.py b/lightllm/common/basemodel/layer_weights/base_layer_weight.py
index 1875e2c3b..b1d992a7c 100644
--- a/lightllm/common/basemodel/layer_weights/base_layer_weight.py
+++ b/lightllm/common/basemodel/layer_weights/base_layer_weight.py
@@ -33,7 +33,11 @@ def verify_load(self):
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, BaseWeight):
- assert attr.verify_load(), f"Loading {attr_name} of layers {self.layer_num_} fails."
+                layer_num = getattr(self, "layer_num_", None)
+                assert attr.verify_load(), f"Loading {attr_name} of layer {layer_num} fails."
def _cuda(self, cpu_tensor):
return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id())
diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index 095f73679..32ccbe833 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -37,4 +37,5 @@
Tarsier2LlamaTpPartModel,
)
from lightllm.models.gpt_oss.model import GptOssTpPartModel
+from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel
from .registry import get_model, get_model_class
diff --git a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py
index 756198e89..eace676c0 100644
--- a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py
+++ b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py
@@ -28,13 +28,13 @@ def _get_mrope_position_triton(
local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
image_start_idx = start_loc + local_image_start_idx - cache_len
image_len = tl.load(b_image_len + image_start_num + i)
- image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
+ # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
for j in range(0, image_len, BLOCK_SIZE):
off = j + tl.arange(0, BLOCK_SIZE)
             # video is not handled yet, so t is always 0
t_pos = local_image_start_idx + off * 0
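+            # patches are laid out row-major, so the row (h) index advances once every image_w tokens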
- h_pos = local_image_start_idx + off // image_h
+ h_pos = local_image_start_idx + off // image_w
w_pos = local_image_start_idx + off % image_w
tl.store(
position_ids + off + image_start_idx,
diff --git a/lightllm/models/qwen3_omni_moe_thinker/__init__.py b/lightllm/models/qwen3_omni_moe_thinker/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py
new file mode 100644
index 000000000..833cc8f4b
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py
@@ -0,0 +1,185 @@
+import torch
+import numpy as np
+from typing import TYPE_CHECKING, Any, Optional, Union, Tuple
+from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
+from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.utils import TensorType
+
+
+class WhisperFeatureExtractor(SequenceFeatureExtractor):
+
+ model_input_names = ["input_features"]
+
+ def __init__(
+ self,
+ feature_size=80,
+ sampling_rate=16000,
+ hop_length=160,
+ chunk_length=30,
+ n_fft=400,
+ padding_value=0.0,
+ dither=0.0,
+ return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
+ **kwargs,
+ ):
+ super().__init__(
+ feature_size=feature_size,
+ sampling_rate=sampling_rate,
+ padding_value=padding_value,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.chunk_length = chunk_length
+ self.n_samples = chunk_length * sampling_rate
+ self.nb_max_frames = self.n_samples // hop_length
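+        # with the defaults above: 30 s * 16000 Hz = 480000 samples -> 480000 // 160 = 3000 mel frames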
+ self.sampling_rate = sampling_rate
+ self.dither = dither
+ self.mel_filters = mel_filter_bank(
+ num_frequency_bins=1 + n_fft // 2,
+ num_mel_filters=feature_size,
+ min_frequency=0.0,
+ max_frequency=8000.0,
+ sampling_rate=sampling_rate,
+ norm="slaney",
+ mel_scale="slaney",
+ )
+
+ def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray:
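+        # Whisper-style log-mel features: STFT -> mel filterbank -> log10, floored to within
+        # 8.0 of the running max, then shifted and scaled via (log_spec + 4.0) / 4.0.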
+ waveform = torch.from_numpy(waveform).to(device, torch.float32)
+ window = torch.hann_window(self.n_fft, device=device)
+
+ if self.dither != 0.0:
+ waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)
+
+ stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
+ magnitudes = stft[..., :-1].abs() ** 2
+
+ mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
+ mel_spec = mel_filters.T @ magnitudes
+
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+ if waveform.dim() == 2:
+ max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
+ log_spec = torch.maximum(log_spec, max_val - 8.0)
+ else:
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+ log_spec = (log_spec + 4.0) / 4.0
+ if device != "cpu":
+ log_spec = log_spec.detach().cpu()
+ return log_spec.numpy()
+
+ # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.
+ # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
+ def zero_mean_unit_var_norm(
+ self, input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0
+ ) -> list[np.ndarray]:
+ if attention_mask is not None:
+ attention_mask = np.array(attention_mask, np.int32)
+ normed_input_values = []
+
+ for vector, length in zip(input_values, attention_mask.sum(-1)):
+ normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+ if length < normed_slice.shape[0]:
+ normed_slice[length:] = padding_value
+
+ normed_input_values.append(normed_slice)
+ else:
+ normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
+
+ return normed_input_values
+
+ def _preprocess(
+ self,
+ raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+ truncation: bool = True,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = None,
+        padding: Optional[str] = "longest",  # "max_length" means pad up to max_length
+ max_length: Optional[int] = None,
+ sampling_rate: Optional[int] = 16000,
+ do_normalize: Optional[bool] = None,
+ device: Optional[str] = "cpu",
+ return_token_timestamps: Optional[bool] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+ if is_batched_numpy and len(raw_speech.shape) > 2:
+ raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+ is_batched = is_batched_numpy or (
+ isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+ )
+
+ if is_batched:
+ raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
+ raw_speech = raw_speech.astype(np.float32)
+
+ # always return batch
+ if not is_batched:
+ raw_speech = [np.asarray([raw_speech]).T]
+
+ batched_speech = BatchFeature({"input_features": raw_speech})
+
+ # convert into correct format for padding
+
+ padded_inputs = self.pad(
+ batched_speech,
+ padding=padding,
+ max_length=max_length if max_length else self.n_samples,
+ truncation=truncation,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask or do_normalize,
+ )
+
+ # zero-mean and unit-variance normalization
+ if do_normalize:
+ padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
+ padded_inputs["input_features"],
+ attention_mask=padded_inputs["attention_mask"],
+ padding_value=self.padding_value,
+ )
+ padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)
+
+ # make sure list is in array format
+ input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
+
+ input_features = self._torch_extract_fbank_features(input_features[0], device)
+
+ if isinstance(input_features[0], list):
+ padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
+
+ else:
+ padded_inputs["input_features"] = input_features
+
+ if return_attention_mask:
+ # rescale from sample (48000) to feature (3000)
+ rescaled_attention_mask = padded_inputs["attention_mask"][:, :: self.hop_length]
+
+ # The STFT computation produces L//hop_length + 1 frames,
+ # but we skip the last frame (see `_torch_extract_fbank_features`).
+ # This means we need to trim the rescaled attention mask to match
+ # the actual number of frames (L//hop_length) when the input length
+ # is not perfectly divisible by the hop length.
+ if padded_inputs["attention_mask"].shape[1] % self.hop_length != 0:
+ rescaled_attention_mask = rescaled_attention_mask[:, :-1]
+ padded_inputs["attention_mask"] = rescaled_attention_mask
+
+ if return_token_timestamps is not None:
+ padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]
+
+ if return_tensors is not None:
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+ input_features = torch.from_numpy(np.asarray(padded_inputs["input_features"], dtype=np.float32)).to(
+ device="cuda", dtype=torch.bfloat16
+ )
+ attention_mask = torch.from_numpy(np.asarray(padded_inputs["attention_mask"], dtype=np.float32)).to(
+ device="cuda", dtype=torch.int32
+ )
+ return input_features, attention_mask
diff --git a/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py
new file mode 100644
index 000000000..1c09ebf44
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py
@@ -0,0 +1,6 @@
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+
+
+class Qwen3OmniMOEInferStateInfo(Qwen3VLInferStateInfo):
+ def __init__(self):
+ super().__init__()
diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py
new file mode 100644
index 000000000..1a05a752f
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,15 @@
+import torch
+from lightllm.models.qwen3_vl_moe.layer_infer.transformer_layer_infer import Qwen3VLMOETransformerLayerInfer
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class Qwen3OmniMOETransformerLayerInfer(Qwen3VLMOETransformerLayerInfer):
+ def __init__(self, layer_num, network_config):
+ super().__init__(layer_num, network_config)
+ self.head_dim_ = network_config["head_dim"]
+ self.mrope_section = torch.tensor(
+ network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda"
+ )
+ return
diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..5ac8060c4
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,30 @@
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight
+
+
+class Qwen3OmniMOEThinkerPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+
+ hidden_size = network_config["hidden_size"]
+ vocab_size = network_config["vocab_size"]
+ self.wte_weight_ = EmbeddingWeight(
+ dim=hidden_size,
+ vocab_size=vocab_size,
+ weight_name="thinker.model.embed_tokens.weight",
+ data_type=self.data_type_,
+ )
+ tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
+ self.lm_head_weight_ = LMHeadWeight(
+ dim=hidden_size,
+ vocab_size=vocab_size,
+ weight_name="thinker.lm_head.weight",
+ data_type=self.data_type_,
+ embedding_weight=self.wte_weight_ if tie_word_embeddings else None,
+ )
+ self.final_norm_weight_ = RMSNormWeight(
+ dim=hidden_size,
+ weight_name="thinker.model.norm.weight",
+ data_type=self.data_type_,
+ )
+ return
diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py
new file mode 100644
index 000000000..775ba5ffe
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py
@@ -0,0 +1,53 @@
+import os
+from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight
+
+
+class Qwen3OmniMOEThinkerTransformerLayerWeight(Qwen3MOETransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, quant_cfg)
+ return
+
+ def _init_weight_names(self):
+ self._q_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_proj.weight"
+ self._q_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_norm.weight"
+ self._q_bias_name = None
+ self._k_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_proj.weight"
+ self._k_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_norm.weight"
+ self._k_bias_name = None
+ self._v_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.v_proj.weight"
+ self._v_bias_name = None
+ self._kv_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.kv_proj.weight"
+ self._kv_bias_name = None
+ self._o_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.o_proj.weight"
+ self._o_bias_name = None
+ self._att_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.input_layernorm.weight"
+ self._att_norm_bias_name = None
+ self._ffn_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.post_attention_layernorm.weight"
+ self._ffn_norm_bias_name = None
+
+ def _init_moe(self):
+ moe_intermediate_size = self.network_config_["moe_intermediate_size"]
+ self.moe_gate = ROWMMWeight(
+ in_dim=self.network_config_["hidden_size"],
+ out_dims=[self.n_routed_experts],
+ weight_names=f"thinker.model.layers.{self.layer_num_}.mlp.gate.weight",
+ data_type=self.data_type_,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.experts = FusedMoeWeight(
+ gate_proj_name="gate_proj",
+ down_proj_name="down_proj",
+ up_proj_name="up_proj",
+ e_score_correction_bias_name="",
+ weight_prefix=f"thinker.model.layers.{self.layer_num_}.mlp.experts",
+ n_routed_experts=self.n_routed_experts,
+ hidden_size=self.network_config_["hidden_size"],
+ moe_intermediate_size=moe_intermediate_size,
+ data_type=self.data_type_,
+ quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
+ layer_num=self.layer_num_,
+ network_config=self.network_config_,
+ )
diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py
new file mode 100644
index 000000000..2e863da00
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/model.py
@@ -0,0 +1,158 @@
+import os
+import json
+import librosa
+from io import BytesIO
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
+
+from lightllm.models.qwen3_omni_moe_thinker.layer_infer.transformer_layer_infer import Qwen3OmniMOETransformerLayerInfer
+from lightllm.models.qwen3_omni_moe_thinker.layer_weights.pre_and_post_layer_weight import (
+ Qwen3OmniMOEThinkerPreAndPostLayerWeight,
+)
+from lightllm.models.qwen3_omni_moe_thinker.layer_weights.transformers_layer_weight import (
+ Qwen3OmniMOEThinkerTransformerLayerWeight,
+)
+
+from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel
+from lightllm.models.qwen3_omni_moe_thinker.infer_struct import Qwen3OmniMOEInferStateInfo
+from lightllm.models.qwen3_vl.model import QWen3VLTokenizer
+from lightllm.server.core.objs import SamplingParams
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+
+
+MIN_AUDIO_LEN = 480
+
+
+class QWen3OmniTokenizer(QWen3VLTokenizer):
+ def __init__(self, tokenizer=None, processor=None, **kwargs):
+ self.tokenizer = tokenizer
+ # image
+ self.image_processor = processor.image_processor
+ self.min_pixel = self.image_processor.min_pixels
+ self.max_pixel = self.image_processor.max_pixels
+ self.patch_size = self.image_processor.patch_size
+ self.merge_size = self.image_processor.merge_size
+
+ # audio
+ self.audio_processor = processor.feature_extractor
+ self.sampling_rate = self.audio_processor.sampling_rate
+ self.n_samples = self.audio_processor.n_samples
+ self.hop_length = self.audio_processor.hop_length
+
+ self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"]
+ self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"]
+ self.image_token_id = kwargs["model_cfg"]["image_token_id"]
+
+ self.audio_start_id = kwargs["model_cfg"]["audio_start_token_id"]
+ self.audio_end_id = kwargs["model_cfg"]["audio_end_token_id"]
+ self.audio_token_id = kwargs["model_cfg"]["audio_token_id"]
+
+ def init_audioitem_extral_params(
+ self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+ ):
+ return
+
+ def get_audio_token_length(self, audio: AudioItem):
+        # Cap the audio length at 30 s worth of samples here; anything longer is truncated in later processing.
+ length = min(audio.audio_length, int(self.n_samples))
+ token_num = self._caclu_audio_token_num(length)
+ return token_num
+
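+        # Mirrors the audio encoder's output-length formula: the mel spectrogram is split into
+        # 100-frame (1 s) chunks, each full chunk yields 13 tokens, and the leftover frames are
+        # reduced by the encoder's three stride-2 convolutions (~8x). As a rough example, assuming
+        # the Whisper defaults (hop_length=160 at 16 kHz): a 10 s clip -> 1000 mel frames -> 130 tokens.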
+ def _caclu_audio_token_num(self, input_audio_len: int):
+ _mel_len = input_audio_len // int(self.hop_length)
+ input_lengths_leave = _mel_len % 100
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
+ output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (_mel_len // 100) * 13
+ return output_lengths
+
+ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+ origin_ids = self.tokenizer.encode(prompt)
+
+        # drop the image/audio pad tokens so only the start/end marker pairs remain
+        origin_ids = [token for token in origin_ids if token not in (self.image_token_id, self.audio_token_id)]
+        # each start/end pair is then expanded to: start, id, id+1 ... id+num, end
+ input_ids = []
+ image_id = 0
+ while True:
+ try:
+ start_idx = origin_ids.index(self.image_start_id)
+ if start_idx + 1 >= len(origin_ids):
+ break
+ if origin_ids[start_idx + 1] == self.image_end_id:
+ input_ids.extend(origin_ids[: start_idx + 1])
+ token_id = multimodal_params.images[image_id].token_id
+ token_num = multimodal_params.images[image_id].token_num
+ multimodal_params.images[image_id].start_idx = len(input_ids)
+ input_ids.extend(range(token_id, token_id + token_num))
+ input_ids.append(self.image_end_id)
+ origin_ids = origin_ids[start_idx + 2 :]
+ image_id += 1
+ else:
+ raise ValueError("image token error")
+ except ValueError:
+ break
+ if multimodal_params:
+ image_cnt = len(multimodal_params.images)
+ if image_cnt != image_id:
+                raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!")
+ input_ids.extend(origin_ids)
+
+ # audio
+ origin_ids = input_ids
+ input_ids = []
+ audio_id = 0
+ start_idx = 0
+ while True:
+ try:
+ start_idx = origin_ids.index(self.audio_start_id)
+ if start_idx + 1 >= len(origin_ids):
+ break
+ if origin_ids[start_idx + 1] == self.audio_end_id:
+ input_ids.extend(origin_ids[: start_idx + 1])
+ token_id = multimodal_params.audios[audio_id].token_id
+ token_num = multimodal_params.audios[audio_id].token_num
+ input_ids.extend(range(token_id, token_id + token_num))
+ input_ids.append(self.audio_end_id)
+ origin_ids = origin_ids[start_idx + 2 :]
+ audio_id += 1
+ else:
+ raise ValueError("audio token error")
+ except ValueError:
+ break
+ if multimodal_params:
+ audio_cnt = len(multimodal_params.audios)
+ if audio_cnt != audio_id:
+                raise ValueError(f"invalid audio tag num: {audio_cnt} vs {audio_id}!")
+ input_ids.extend(origin_ids)
+
+ return input_ids
+
+
+@ModelRegistry(["qwen3_omni_moe"], is_multimodal=True)
+class Qwen3OmniMOETpPartModel(Qwen3VLMOETpPartModel):
+
+ pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer
+ transformer_layer_infer_class = Qwen3OmniMOETransformerLayerInfer
+
+ pre_and_post_weight_class = Qwen3OmniMOEThinkerPreAndPostLayerWeight
+ transformer_weight_class = Qwen3OmniMOEThinkerTransformerLayerWeight
+
+ infer_state_class = Qwen3OmniMOEInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
+ self.config = all_config["thinker_config"]["text_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
new file mode 100644
index 000000000..c66033f53
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
@@ -0,0 +1,388 @@
+import os
+import json
+import math
+import torch
+import rpyc
+import librosa
+import numpy as np
+from io import BytesIO
+from torch import Tensor, nn
+from safetensors import safe_open
+from torch.nn import functional as F
+from typing import Callable, Optional, Union, List
+from rpyc.utils.classic import obtain
+
+from transformers.activations import ACT2FN
+
+from lightllm.server.multimodal_params import AudioItem
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from lightllm.models.qwen3_omni_moe_thinker.audio_process import WhisperFeatureExtractor
+
+
+def _get_feat_extract_output_lengths(input_lengths):
+ """
+ Computes the output length of the convolutional layers and the output length of the audio encoder
+ """
+
+ input_lengths_leave = input_lengths % 100
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
+ output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
+ return output_lengths
+
+
+class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ encoder_attention_heads,
+ attention_dropout,
+ dropout,
+ activation_function,
+ activation_dropout,
+ encoder_ffn_dim,
+ ):
+ super().__init__()
+ self.embed_dim = d_model
+ self.self_attn = Qwen3OmniMoeAudioAttention(d_model, encoder_attention_heads, attention_dropout)
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = dropout
+ self.activation_fn = ACT2FN[activation_function]
+ self.activation_dropout = activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
+ self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int,
+ **kwargs,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16:
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ return outputs
+
+
+class Qwen3OmniMoeAudioAttention(nn.Module):
+ def __init__(self, d_model, encoder_attention_heads, attention_dropout):
+ super().__init__()
+ self.embed_dim = d_model
+ self.num_heads = encoder_attention_heads
+ self.dropout = attention_dropout
+ self.head_dim = self.embed_dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+
+ if (self.head_dim * self.num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.scaling = self.head_dim ** -0.5
+ self.attention_dropout = 0.0
+ self.is_decoder = False
+ self.is_causal = False
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: int = 0,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ seq_length, _ = hidden_states.size()
+
+ q = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
+ k = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
+ v = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
+
+ attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
+
+ flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+ return attn_output
+
+
+class SinusoidsPositionEmbedding(nn.Module):
+ def __init__(self, length, channels, max_timescale=10000):
+ super().__init__()
+ if channels % 2 != 0:
+ raise ValueError("SinusoidsPositionEmbedding needs even channels input")
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ self.positional_embedding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+ def forward(self, seqlen: int):
+ return self.positional_embedding[:seqlen, :]
+
+
+class Qwen3OmniMoeAudioEncoder(nn.Module):
+ def __init__(
+ self,
+ kvargs,
+ dropout=0,
+ d_model=1280,
+ num_mel_bins=128,
+ max_source_positions=1500,
+ scale_embedding=False,
+ n_window=50,
+ encoder_layers=32,
+ downsample_hidden_size=480,
+ activation_function="gelu",
+ output_dim=2048,
+ n_window_infer=800,
+ conv_chunksize=500,
+ encoder_attention_heads=20,
+ attention_dropout=0,
+ activation_dropout=0,
+ encoder_ffn_dim=5120,
+ ):
+ super().__init__()
+ self.data_type = kvargs.get("data_type", "bfloat16")
+ self.dropout = dropout
+ self.embed_dim = d_model
+ self.num_mel_bins = num_mel_bins
+ self.max_source_positions = max_source_positions
+ self.embed_scale = math.sqrt(self.embed_dim) if scale_embedding else 1.0
+ self.n_window = n_window
+ self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, self.embed_dim)
+ self.layers = nn.ModuleList(
+ [
+ Qwen3OmniMoeAudioEncoderLayer(
+ d_model,
+ encoder_attention_heads,
+ attention_dropout,
+ dropout,
+ activation_function,
+ activation_dropout,
+ encoder_ffn_dim,
+ )
+ for _ in range(encoder_layers)
+ ]
+ )
+ self.ln_post = nn.LayerNorm(d_model)
+ self.gradient_checkpointing = False
+ self.conv2d1 = nn.Conv2d(1, downsample_hidden_size, 3, 2, padding=1)
+ self.conv2d2 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1)
+ self.conv2d3 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1)
+ self.conv_out = nn.Linear(
+ downsample_hidden_size * ((((num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
+ d_model,
+ bias=False,
+ )
+ self.proj1 = nn.Linear(d_model, d_model)
+ self.act = ACT2FN[activation_function]
+ self.proj2 = nn.Linear(d_model, output_dim)
+ self.n_window_infer = n_window_infer
+ self.conv_chunksize = conv_chunksize
+
+ self.cache_port = kvargs["cache_port"]
+ self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
+ self._init_datatype()
+
+ def _init_datatype(self):
+ if isinstance(self.data_type, torch.dtype):
+ return
+ if self.data_type in ["fp16", "float16"]:
+ self.data_type = torch.float16
+ elif self.data_type in ["bf16", "bfloat16"]:
+ self.data_type = torch.bfloat16
+ elif self.data_type in ["fp32", "float32"]:
+ self.data_type = torch.float32
+ else:
+            raise ValueError(f"Unsupported data type {self.data_type}!")
+ return
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.conv2d1
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.conv2d1 = value
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers and the output length of the audio encoder
+ """
+ input_lengths = (input_lengths - 1) // 2 + 1
+ output_lengths = (input_lengths - 2) // 2 + 1
+ return input_lengths, output_lengths
+
+ def load_model(self, weight_dir, config):
+ processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+ with open(processor_config_path, "r") as f:
+ processor_config_dict = json.load(f)
+ self.processor = WhisperFeatureExtractor(**processor_config_dict)
+
+ bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+ if bin_weight_files:
+ weight_dict = {}
+ for file_ in bin_weight_files:
+ f = torch.load(os.path.join(weight_dir, file_), "cpu")
+ for k, v in f.items():
+ if "thinker.audio_tower" in k:
+ weight_dict[k[len("thinker.audio_tower.") :]] = v
+ else:
+ hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+ weight_dict = {}
+ for file_ in hf_weight_files:
+ f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+ for k in f.keys():
+ if "thinker.audio_tower" in k:
+ weight_dict[k[len("thinker.audio_tower.") :]] = f.get_tensor(k)
+
+ self.load_state_dict(weight_dict)
+
+ def forward(
+ self,
+ input_features,
+ feature_lens=None,
+ aftercnn_lens=None,
+ ):
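+        # The mel features are split into windows of n_window * 2 (=100) frames, run through
+        # the conv front-end in batches of conv_chunksize windows to bound memory, and then
+        # attended within windows derived from n_window_infer using the cu_seqlens boundaries.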
+ aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
+ chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
+
+ chunk_lengths = torch.tensor(
+ [self.n_window * 2] * chunk_num.sum(),
+ dtype=torch.long,
+ device=feature_lens.device,
+ )
+ tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
+ chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
+ chunk_lengths[chunk_lengths == 0] = self.n_window * 2
+
+ chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
+ padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
+ feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
+ padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
+ [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
+ batch_first=True,
+ )
+ padded_feature = padded_feature.unsqueeze(1)
+ # Split to chunk to avoid OOM during convolution
+ padded_embeds = []
+ for chunk in padded_feature.split(self.conv_chunksize, dim=0):
+ padded_embed = F.gelu(self.conv2d1(chunk))
+ padded_embed = F.gelu(self.conv2d2(padded_embed))
+ padded_embed = F.gelu(self.conv2d3(padded_embed))
+ padded_embeds.append(padded_embed)
+ padded_embed = torch.cat(padded_embeds, dim=0)
+ b, c, f, t = padded_embed.size()
+ padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))
+
+ positional_embedding = (
+ self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
+ .unsqueeze(0)
+ .to(padded_embed.dtype)
+ .to(padded_embed.device)
+ )
+ padded_embed = padded_embed + positional_embedding
+ hidden_states = padded_embed[padded_mask_after_cnn]
+ cu_chunk_lens = [0]
+ window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
+ for cnn_len in aftercnn_lens:
+ cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
+ remainder = cnn_len % window_aftercnn
+ if remainder != 0:
+ cu_chunk_lens += [remainder]
+ cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ cu_seqlens,
+ max_seqlen,
+ )
+ hidden_states = layer_outputs[0]
+
+ hidden_states = self.ln_post(hidden_states)
+ hidden_states = self.proj1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.proj2(hidden_states)
+ return hidden_states
+
+ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
+ uuids = []
+ items: List[AudioItem] = []
+ per_audio_features: List[torch.Tensor] = []
+ for i, item in enumerate(audio_items):
+ if isinstance(item, AudioItem):
+ uuids.append(item.uuid)
+ items.append(item)
+ audio_data = read_shm(get_shm_name_data(item.uuid))
+ audio = BytesIO(audio_data)
+ audio, _ = librosa.load(audio, sr=self.processor.sampling_rate)
+ else:
+ raise ValueError(f"cannot read audio which type is {type(item)}!")
+
+ input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True)
+ if feature_attention_mask is not None:
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+ input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
+ else:
+ audio_feature_lengths = None
+
+ feature_lens = (
+ audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
+ )
+
+ audio_features = self.forward(
+ input_features,
+ feature_lens=feature_lens,
+ )
+ per_audio_features.append(audio_features)
+
+ ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
+ ids_to_set = []
+ for i, ready in enumerate(ready_audio):
+ if ready:
+ continue
+
+ uid = uuids[i]
+ item = items[i]
+
+ cur_embed = per_audio_features[i]
+ cpu_embed_cache_client.copy_to_cache(
+ embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache
+ )
+ ids_to_set.append(uid)
+
+ if ids_to_set:
+ self.cache_client.root.set_items_embed(ids=ids_to_set)
+ torch.cuda.current_stream().synchronize()
diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py
new file mode 100644
index 000000000..dd9b54ee8
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py
@@ -0,0 +1,408 @@
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import time
+from PIL import Image
+from io import BytesIO
+from typing import List
+from safetensors import safe_open
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.activations import ACT2FN
+
+from lightllm.server.multimodal_params import ImageItem
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
+from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention
+
+
+class Qwen3OmniMoeVisionMLP(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+ self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, hidden_state):
+ return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+
+class Qwen3OmniMoeVisionPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ embed_dim: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ if hidden_states.dtype != target_dtype:
+ hidden_states = hidden_states.to(target_dtype)
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Qwen3OmniMoeVisionPatchMerger(nn.Module):
+ def __init__(self, hidden_size, out_hidden_size, spatial_merge_size, use_postshuffle_norm=False) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size * (spatial_merge_size ** 2)
+ self.use_postshuffle_norm = use_postshuffle_norm
+ self.ln_q = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else hidden_size, eps=1e-6)
+ self.mlp = nn.ModuleList(
+ [
+ nn.Linear(self.hidden_size, self.hidden_size),
+ nn.GELU(),
+ nn.Linear(self.hidden_size, out_hidden_size),
+ ]
+ )
+
+ def forward(self, hidden: torch.Tensor) -> torch.Tensor:
+ hidden = self.ln_q(hidden.view(-1, self.hidden_size) if self.use_postshuffle_norm else hidden).view(
+ -1, self.hidden_size
+ )
+ for layer in self.mlp:
+ hidden = layer(hidden)
+ return hidden
+
+
+class Qwen3OmniMoeVisionBlock(nn.Module):
+ def __init__(self, hidden_size, intermediate_size, num_heads, hidden_act) -> None:
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
+ self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
+
+ self.attn = VisionFlashAttention(hidden_size, num_heads=num_heads)
+ self.mlp = Qwen3OmniMoeVisionMLP(hidden_size, intermediate_size, hidden_act)
+
+ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ rotary_cos=rotary_cos,
+ rotary_sin=rotary_sin,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Qwen3OmniMoeVisionTransformerPretrainedModel(nn.Module):
+ def __init__(
+ self,
+ kvargs,
+ depth=27,
+ out_hidden_size=4096,
+ hidden_size=1152,
+ hidden_act="gelu_pytorch_tanh",
+ intermediate_size=4304,
+ deepstack_visual_indexes=[8, 16, 24],
+ num_heads=16,
+ in_channels=3,
+ patch_size=16,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ num_position_embeddings=2304,
+ **kwargs,
+ ):
+ super().__init__()
+ self.data_type = kvargs.get("data_type", "bfloat16")
+
+ self.depth = depth
+ self.out_hidden_size = out_hidden_size
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.intermediate_size = intermediate_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.num_position_embeddings = num_position_embeddings
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.deepstack_visual_indexes = deepstack_visual_indexes
+ self.merger_list = nn.ModuleList(
+ [
+ Qwen3OmniMoeVisionPatchMerger(
+ hidden_size=self.hidden_size,
+ out_hidden_size=self.out_hidden_size,
+ spatial_merge_size=self.spatial_merge_size,
+ use_postshuffle_norm=True,
+ )
+ for _ in range(len(self.deepstack_visual_indexes))
+ ]
+ )
+
+ self.patch_embed = Qwen3OmniMoeVisionPatchEmbed(
+ patch_size=self.patch_size,
+ temporal_patch_size=self.temporal_patch_size,
+ in_channels=self.in_channels,
+ embed_dim=self.hidden_size,
+ )
+
+ self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
+ self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
+
+ head_dim = self.hidden_size // self.num_heads
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).cuda()
+
+ self.blocks = nn.ModuleList(
+ [
+ Qwen3OmniMoeVisionBlock(self.hidden_size, self.intermediate_size, self.num_heads, self.hidden_act)
+ for _ in range(self.depth)
+ ]
+ )
+ self.merger = Qwen3OmniMoeVisionPatchMerger(
+ hidden_size=self.hidden_size,
+ out_hidden_size=self.out_hidden_size,
+ spatial_merge_size=self.spatial_merge_size,
+ use_postshuffle_norm=False,
+ )
+ self._init_datatype()
+
+ def _init_datatype(self):
+ if isinstance(self.data_type, torch.dtype):
+ return
+ if self.data_type in ["fp16", "float16"]:
+ self.data_type = torch.float16
+ elif self.data_type in ["bf16", "bfloat16"]:
+ self.data_type = torch.bfloat16
+ elif self.data_type in ["fp32", "float32"]:
+ self.data_type = torch.float32
+ else:
+            raise ValueError(f"Unsupported data type {self.data_type}!")
+ return
+
+ def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids):
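+        # For each image span in valid_ids, stack the final embedding with its deepstack features
+        # along a new dim so that every token row carries (1 + num_deepstack) embeddings.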
+ all_chunks = []
+
+ for start, end in valid_ids:
+ hs_i = image_embed[start:end]
+ ds_i_list = [feat[start:end] for feat in deepstack_feature_lists]
+ combined_i = torch.cat([hs_i, *ds_i_list], dim=1).view((end - start), len(ds_i_list) + 1, hs_i.shape[-1])
+ all_chunks.append(combined_i)
+
+ all_img_embeds_ds = torch.cat(all_chunks, dim=0)
+ return all_img_embeds_ds, valid_ids
+
+ def load_model(self, weight_dir):
+
+ processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+ with open(processor_config_path, "r") as f:
+ processor_config_dict = json.load(f)
+ self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+
+ bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+ if bin_weight_files:
+ weight_dict = {}
+ for file_ in bin_weight_files:
+ f = torch.load(os.path.join(weight_dir, file_), "cpu")
+ for k, v in f.items():
+ if "thinker.visual" in k:
+ weight_dict[k[len("thinker.visual.") :]] = v
+ else:
+ hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+ weight_dict = {}
+ for file_ in hf_weight_files:
+ f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+ for k in f.keys():
+ if "thinker.visual" in k:
+ weight_dict[k[len("thinker.visual.") :]] = f.get_tensor(k)
+
+ self.load_state_dict(weight_dict)
+
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
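+        # Build per-patch (row, col) indices grouped by spatial-merge window to match the token
+        # layout, then gather rotary cos/sin for every patch from the shared frequency table.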
+ merge_size = self.spatial_merge_size
+
+ total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
+ pos_ids = torch.empty((total_tokens, 2), dtype=torch.long)
+
+ offset = 0
+ for num_frames, height, width in grid_thw:
+ merged_h, merged_w = height // merge_size, width // merge_size
+
+ block_rows = torch.arange(merged_h) # block row indices
+ block_cols = torch.arange(merged_w) # block col indices
+ intra_row = torch.arange(merge_size) # intra-block row offsets
+ intra_col = torch.arange(merge_size) # intra-block col offsets
+
+ # Compute full-resolution positions
+ row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
+ col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
+
+ row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+ col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+
+ coords = torch.stack((row_idx, col_idx), dim=-1)
+
+ if num_frames > 1:
+ coords = coords.repeat(num_frames, 1)
+
+ num_tokens = coords.shape[0]
+ pos_ids[offset : offset + num_tokens] = coords
+ offset += num_tokens
+
+ max_hw = int(grid_thw[:, 1:].max().item())
+ cos_full, sin_full = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
+ cos = cos_full[pos_ids].flatten(1)
+ sin = sin_full[pos_ids].flatten(1)
+ return cos, sin
+
+ def fast_pos_embed_interpolate(self, grid_thw):
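+        # Bilinearly interpolate the learned pos_embed table (num_grid_per_side x num_grid_per_side)
+        # onto each image's (h, w) patch grid: gather the four neighbouring grid entries, blend them
+        # with their bilinear weights, then reorder to the spatial-merge window layout.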
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+ for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+ h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * self.num_grid_per_side
+ base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
+
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+ weight_tensor = torch.tensor(
+ weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+ )
+ pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+ patch_pos_embeds_permute = []
+ merge_size = self.spatial_merge_size
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (
+ pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+ .permute(0, 1, 3, 2, 4, 5)
+ .flatten(0, 4)
+ )
+ patch_pos_embeds_permute.append(pos_embed)
+ patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+ return patch_pos_embeds
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
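+        # Outputs of the blocks listed in deepstack_visual_indexes are passed through their own
+        # patch mergers and returned as deepstack features alongside the final merged embedding.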
+ hidden_states = self.patch_embed(hidden_states)
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+ rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw)
+ rotary_cos = rotary_cos.to("cuda", non_blocking=True)
+ rotary_sin = rotary_sin.to("cuda", non_blocking=True)
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ dtype=torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True)
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ deepstack_feature_lists = []
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ rotary_cos=rotary_cos,
+ rotary_sin=rotary_sin,
+ )
+ if layer_num in self.deepstack_visual_indexes:
+ deepstack_feature = self.merger_list[self.deepstack_visual_indexes.index(layer_num)](hidden_states)
+ deepstack_feature_lists.append(deepstack_feature)
+
+        hidden_states = self.merger(hidden_states)
+        return hidden_states, deepstack_feature_lists
+
+ def encode(self, images: List[ImageItem]):
+ img_tensors = []
+ valid_ids = []
+ valid_id = 0
+ img_grids = []
+ uuids = []
+ for i, img in enumerate(images):
+ if isinstance(img, ImageItem):
+ uuids.append(img.uuid)
+ image_data = read_shm(get_shm_name_data(img.uuid))
+ image_data = Image.open(BytesIO(image_data))
+ pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+ img_tensors.append(pixel_values)
+ img_grids.append(image_grid_thw)
+ else:
+                raise Exception("Unsupported input type: {} for {}".format(type(img), img))
+
+            # the patch count must be divisible by merge_length (spatial_merge_size ** 2)
+ cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2)
+
+ valid_ids.append([valid_id, valid_id + cur_num])
+ valid_id += cur_num
+
+ if len(img_tensors) <= 0:
+ return None
+
+ imgs = torch.cat(img_tensors, dim=0)
+ grid_thw = torch.cat(img_grids, dim=0)
+
+ pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+ image_grid_thw = grid_thw.to("cuda", non_blocking=True)
+ img_embeds, deepstack_feature_lists = self.forward(pixel_values, grid_thw=image_grid_thw)
+ all_img_embeds_df, valid_ids = self.concat_img_embed_and_deepstack_features(
+ img_embeds, deepstack_feature_lists, valid_ids
+ )
+
+ return all_img_embeds_df, uuids, valid_ids
diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py
index 4ccc6da37..0bf716fc5 100644
--- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py
@@ -23,8 +23,10 @@ def _get_qkv(
layer_weight: Qwen3MOETransformerLayerWeight,
) -> Tuple[torch.Tensor, torch.Tensor]:
input = input.view(-1, self.embed_dim_)
- q = layer_weight.q_proj.mm(input)
- cache_kv = layer_weight.kv_proj.mm(input)
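+        # q, k and v projections are fused into a single qkv_proj; split its output by per-rank head dims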
+ qkv = layer_weight.qkv_proj.mm(input)
+ q, cache_kv = qkv.split(
+ [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1
+ )
layer_weight.q_norm_weight_(
q,
eps=self.eps_,
diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py
index f3c813800..ad18ffd7f 100644
--- a/lightllm/server/audioserver/model_infer/model_rpc.py
+++ b/lightllm/server/audioserver/model_infer/model_rpc.py
@@ -4,6 +4,7 @@
from typing import Dict, List, Tuple
from transformers.configuration_utils import PretrainedConfig
from lightllm.models.whisper.whisper_audio import WhisperAudioModel
+from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_audio import Qwen3OmniMoeAudioEncoder
from lightllm.server.multimodal_params import AudioItem
from lightllm.utils.infer_utils import set_random_seed
from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
@@ -19,6 +20,9 @@ def exposed_init_model(self, kvargs):
weight_dir = kvargs["weight_dir"]
model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
+ if model_cfg.get("thinker_config") is not None:
+ model_cfg = model_cfg["thinker_config"]
+
audio_config = model_cfg["audio_config"]
model_kvargs = {"cache_port": kvargs["cache_port"], "data_type": kvargs["data_type"]}
@@ -26,6 +30,8 @@ def exposed_init_model(self, kvargs):
self.model_type = audio_config["model_type"]
if self.model_type == "clap_audio_model" or self.model_type == "whisper":
self.model = WhisperAudioModel(model_kvargs)
+ elif self.model_type == "qwen3_omni_moe_audio_encoder":
+ self.model = Qwen3OmniMoeAudioEncoder(model_kvargs).eval().bfloat16()
else:
raise Exception(f"can not support {self.model_type} now")
diff --git a/lightllm/server/embed_cache/embed_cache_client.py b/lightllm/server/embed_cache/embed_cache_client.py
index ad01dd808..8c5b7f71e 100644
--- a/lightllm/server/embed_cache/embed_cache_client.py
+++ b/lightllm/server/embed_cache/embed_cache_client.py
@@ -47,6 +47,19 @@ def copy_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
start_index_in_cache=start_index_in_cache,
)
+ def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
+ from .copy_to_cache import offload_embed_tensor_to_cache
+
+ if embed_tensor.ndim == 3:
+            # qwen3 vision embeds that carry deepstack features are 3-dim; dim 1 must match the cache layout
+ assert embed_tensor.shape[1] == self.cpu_embed_cache_tensor.shape[1]
+
+ offload_embed_tensor_to_cache(
+ embed_tensor=embed_tensor,
+ cache_tensor=self.cpu_embed_cache_tensor,
+ start_index_in_cache=start_index_in_cache,
+ )
+
def _create_shm_embed_kv_cache(self):
shm_ptr = create_shm_kv_cache_ptr(
key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size()
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index e0b2bd425..09bc938f2 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -31,6 +31,7 @@
from ..models.qwen3_vl.model import QWen3VLTokenizer
from ..models.internvl.model import InternvlTokenizer
from ..models.gemma3.model import Gemma3Tokenizer
+from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
@@ -100,6 +101,12 @@ def get_tokenizer(
tokenizer = QWen3VLTokenizer(
tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg
)
+ elif model_cfg.get("thinker_config") is not None:
+ from transformers import AutoProcessor
+
+ model_cfg = model_cfg["thinker_config"]
+ processor = AutoProcessor.from_pretrained(tokenizer_name)
+ tokenizer = QWen3OmniTokenizer(tokenizer, processor=processor, model_cfg=model_cfg)
elif model_type == "internvl_chat":
tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
elif model_type == "gemma3":
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index 8f4a1ee45..3e97f4de3 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -19,6 +19,7 @@
from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel
from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel
from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
+from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.dist_utils import init_vision_distributed_env
from lightllm.utils.graceful_utils import graceful_registry
@@ -80,6 +81,15 @@ def exposed_init_model(self, kvargs):
# self.model = InternVLVisionModel()
elif self.model_type == "gemma3":
self.model = Gemma3VisionModel()
+ elif (
+ model_cfg.get("thinker_config", {}).get("vision_config", {}).get("model_type")
+ == "qwen3_omni_moe_vision_encoder"
+ ):
+ self.model = (
+ Qwen3OmniMoeVisionTransformerPretrainedModel(kvargs, **model_cfg["thinker_config"]["vision_config"])
+ .eval()
+ .bfloat16()
+ )
else:
raise Exception(f"can not support {self.model_type} now")
@@ -117,7 +127,7 @@ def exposed_encode(self, images: List[ImageItem]):
uid = uuids[i]
start, end = valid_ids[i]
image = images[i]
- self.cpu_embed_cache_client.copy_to_cache(
+ self.cpu_embed_cache_client.copy_vision_to_cache(
embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache
)
ids_to_set.append(uid)
diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py
index b06309f96..790f185f2 100644
--- a/lightllm/utils/config_utils.py
+++ b/lightllm/utils/config_utils.py
@@ -25,6 +25,8 @@ def _get_config_llm_keyvalue(model_path: str, key_name: list[str]):
value = config_json["llm_config"][key]
except:
value = config_json.get("text_config", {}).get(key)
+            if config_json.get("thinker_config") is not None:
+                value = config_json["thinker_config"].get("text_config", {}).get(key)
if value is not None:
return value
@@ -77,6 +79,14 @@ def get_layer_num(model_path: str) -> int:
def get_eos_token_ids(model_path: str) -> Optional[List[int]]:
+ try:
+        # qwen3-omni uses a fixed special eos_token_id
+ config_json = get_config_json(model_path)
+ assert config_json["architectures"][0] == "Qwen3OmniMoeForConditionalGeneration"
+ return [151645]
+ except:
+ pass
+
eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"])
if isinstance(eos_token_id, int):
return [eos_token_id]
@@ -100,6 +110,9 @@ def get_model_architectures(model_path: str):
def get_vocab_size(model_path: str):
try:
config_json = get_config_json(model_path)
+        # qwen3-omni nests its text config under thinker_config
+ if "thinker_config" in config_json:
+ config_json = config_json["thinker_config"]
if "llm_config" in config_json:
vocab_size = int(config_json["llm_config"]["vocab_size"])
return vocab_size
diff --git a/lightllm/utils/embed_utils.py b/lightllm/utils/embed_utils.py
index 81b05dc29..dec1537d6 100644
--- a/lightllm/utils/embed_utils.py
+++ b/lightllm/utils/embed_utils.py
@@ -27,12 +27,12 @@ def calcu_embed_cache_meta() -> "EmbedCacheMeta":
args = get_env_start_args()
assert args.enable_multimodal
from lightllm.utils.llm_utils import get_llm_model_class
- from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel
+ from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel, Qwen3OmniMOETpPartModel
model_class = get_llm_model_class()
model_dir = args.model_dir
- if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel]:
+ if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel, Qwen3OmniMOETpPartModel]:
embed_cache_meta_data = EmbedCacheMeta(
token_num=None,
layer_num=4,
diff --git a/requirements.txt b/requirements.txt
index a3b9473f8..89c52ebf8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -61,8 +61,8 @@ sortedcontainers==2.4.0
toolz==0.12.0
torch==2.8.0
tqdm==4.65.0
-transformers==4.53.3
-tokenizers==0.21.1
+transformers==4.57.1
+tokenizers==0.22.1
urllib3==1.26.16
uvicorn==0.19.0
uvloop==0.17.0