From f2b9e1c88f1b0f4f404a0fbb9454497dd1d8a661 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 4 Feb 2026 13:20:53 +0000 Subject: [PATCH 01/20] add-qwen3-omni-thinker --- lightllm/models/__init__.py | 1 + .../pre_and_post_layer_weight.py | 11 + lightllm/models/qwen2_vl/infer_struct.py | 1 + .../triton_kernel/get_mrope_position_ids.py | 3 + .../models/qwen3_omni_moe_thinker/__init__.py | 0 .../qwen3_omni_moe_thinker/audio_process.py | 180 ++++++++ .../qwen3_omni_moe_thinker/infer_struct.py | 7 + .../layer_infer/transformer_layer_infer.py | 30 ++ .../meta_weights/code2wav_causal_conv_net.py | 116 +++++ .../code2wav_causal_trans_conv_net.py | 101 +++++ .../meta_weights/code2wav_conv_ne_xt.py | 165 +++++++ .../meta_weights/talker_resize_mlp_weight.py | 109 +++++ .../pre_and_post_layer_weight.py | 23 + .../transformers_layer_weight.py | 10 + .../models/qwen3_omni_moe_thinker/model.py | 124 ++++++ .../qwen3_omni_audio.py | 413 ++++++++++++++++++ .../qwen3_omni_visual.py | 408 +++++++++++++++++ .../audioserver/model_infer/model_rpc.py | 6 + lightllm/server/tokenizer.py | 7 + .../visualserver/model_infer/model_rpc.py | 10 + lightllm/utils/config_utils.py | 5 + 21 files changed, 1730 insertions(+) create mode 100644 lightllm/models/qwen3_omni_moe_thinker/__init__.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/audio_process.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/infer_struct.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_conv_net.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_trans_conv_net.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_conv_ne_xt.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/talker_resize_mlp_weight.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/model.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py create mode 100644 lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 095f736791..32ccbe8337 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -37,4 +37,5 @@ Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel +from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index 6449430d9e..59c965b692 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -1,7 +1,18 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import LMHeadWeight class Qwen2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size 
= network_config["vocab_size"]
+        tie_word_embeddings = network_config.get("tie_word_embeddings", False)
+        self.lm_head_weight_ = LMHeadWeight(
+            dim=hidden_size,
+            vocab_size=vocab_size,
+            weight_name="thinker.lm_head.weight",
+            data_type=self.data_type_,
+            embedding_weight=self.wte_weight_ if tie_word_embeddings else None,
+        )
         return
diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py
index 747be932d9..2508fd554a 100644
--- a/lightllm/models/qwen2_vl/infer_struct.py
+++ b/lightllm/models/qwen2_vl/infer_struct.py
@@ -79,5 +79,6 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor:
             b_ready_cache_len=self.b_ready_cache_len,
             b_q_seq_len=self.b_q_seq_len,
             b_start_loc=self.b_q_start_loc,
+            use_image_h=getattr(self, "use_image_h", True),
         )
         return position_ids
diff --git a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py
index 756198e89a..b5455fb4cf 100644
--- a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py
+++ b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py
@@ -91,11 +91,14 @@ def get_mrope_position_triton(
     b_ready_cache_len: torch.Tensor,
     b_q_seq_len: torch.Tensor,
     b_start_loc: torch.Tensor,
+    use_image_h: bool = True,
 ) -> torch.Tensor:
     batch_size = b_q_seq_len.shape[0]
     assert batch_size == b_image_nums.shape[0]
     grid = (batch_size,)
+    if not use_image_h:  # could also be changed upstream where these values are generated; use whichever place fits better
+        b_image_thwd[:, 1] = b_image_thwd[:, 2]
     BLOCK_SIZE = 64
     _get_mrope_position_triton[grid](
         b_image_start_idx=b_image_start_idx,
diff --git a/lightllm/models/qwen3_omni_moe_thinker/__init__.py b/lightllm/models/qwen3_omni_moe_thinker/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py
new file mode 100644
index 0000000000..3b15bd271e
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py
@@ -0,0 +1,180 @@
+import torch
+import numpy as np
+from typing import TYPE_CHECKING, Any, Optional, Union, Tuple
+from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
+from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.utils import TensorType
+
+
+class WhisperFeatureExtractor(SequenceFeatureExtractor):
+
+    model_input_names = ["input_features"]
+
+    def __init__(
+        self,
+        feature_size=80,
+        sampling_rate=16000,
+        hop_length=160,
+        chunk_length=30,
+        n_fft=400,
+        padding_value=0.0,
+        dither=0.0,
+        return_attention_mask=False,  # pad inputs to max length with silence token (zero) and no attention mask
+        **kwargs,
+    ):
+        super().__init__(
+            feature_size=feature_size,
+            sampling_rate=sampling_rate,
+            padding_value=padding_value,
+            return_attention_mask=return_attention_mask,
+            **kwargs,
+        )
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.chunk_length = chunk_length
+        self.n_samples = chunk_length * sampling_rate
+        self.nb_max_frames = self.n_samples // hop_length
+        self.sampling_rate = sampling_rate
+        self.dither = dither
+        self.mel_filters = mel_filter_bank(
+            num_frequency_bins=1 + n_fft // 2,
+            num_mel_filters=feature_size,
+            min_frequency=0.0,
+            max_frequency=8000.0,
+            sampling_rate=sampling_rate,
+            norm="slaney",
+            mel_scale="slaney",
+        )
+
+    def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> 
np.ndarray: + waveform = torch.from_numpy(waveform).to(device, torch.float32) + window = torch.hann_window(self.n_fft, device=device) + + if self.dither != 0.0: + waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device) + + stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + mel_spec = mel_filters.T @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + if waveform.dim() == 2: + max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] + log_spec = torch.maximum(log_spec, max_val - 8.0) + else: + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if device != "cpu": + log_spec = log_spec.detach().cpu() + return log_spec.numpy() + + # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2. + # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm + def zero_mean_unit_var_norm( + self, input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0 + ) -> list[np.ndarray]: + if attention_mask is not None: + attention_mask = np.array(attention_mask, np.int32) + normed_input_values = [] + + for vector, length in zip(input_values, attention_mask.sum(-1)): + normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + if length < normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + else: + normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + + return normed_input_values + + def _preprocess( + self, + raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + truncation: bool = True, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: Optional[bool] = None, + padding: Optional[str] = "max_length", + max_length: Optional[int] = None, + sampling_rate: Optional[int] = None, + do_normalize: Optional[bool] = None, + device: Optional[str] = "cpu", + return_token_timestamps: Optional[bool] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [np.asarray([raw_speech]).T] + + batched_speech = BatchFeature({"input_features": raw_speech}) + + # convert into correct format for padding + + padded_inputs = self.pad( + batched_speech, + padding=padding, + max_length=max_length if max_length else self.n_samples, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask or do_normalize, + ) + + # zero-mean and unit-variance normalization + if 
do_normalize: + padded_inputs["input_features"] = self.zero_mean_unit_var_norm( + padded_inputs["input_features"], + attention_mask=padded_inputs["attention_mask"], + padding_value=self.padding_value, + ) + padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0) + + # make sure list is in array format + input_features = padded_inputs.get("input_features").transpose(2, 0, 1) + + input_features = self._torch_extract_fbank_features(input_features[0], device) + + if isinstance(input_features[0], list): + padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] + + else: + padded_inputs["input_features"] = input_features + + if return_attention_mask: + # rescale from sample (48000) to feature (3000) + rescaled_attention_mask = padded_inputs["attention_mask"][:, :: self.hop_length] + + # The STFT computation produces L//hop_length + 1 frames, + # but we skip the last frame (see `_torch_extract_fbank_features`). + # This means we need to trim the rescaled attention mask to match + # the actual number of frames (L//hop_length) when the input length + # is not perfectly divisible by the hop length. + if padded_inputs["attention_mask"].shape[1] % self.hop_length != 0: + rescaled_attention_mask = rescaled_attention_mask[:, :-1] + padded_inputs["attention_mask"] = rescaled_attention_mask + + if return_token_timestamps is not None: + padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech] + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs diff --git a/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py new file mode 100644 index 0000000000..a273f2abae --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py @@ -0,0 +1,7 @@ +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo + + +class Qwen3OmniMOEInferStateInfo(Qwen3VLInferStateInfo): + def __init__(self): + self.use_image_h = False + super().__init__() diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..91c65e6750 --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py @@ -0,0 +1,30 @@ +import os +import torch +import torch.functional as F +import torch.distributed as dist +import numpy as np +import triton +from typing import Tuple +from lightllm.models.qwen3_vl_moe.layer_infer.transformer_layer_infer import Qwen3VLMOETransformerLayerInfer +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +from functools import partial +from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_global_world_size +from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor + +logger = init_logger(__name__) + + +class Qwen3OmniMOETransformerLayerInfer(Qwen3VLMOETransformerLayerInfer): + def __init__(self, layer_num, network_config): + 
self.layer_num_ = network_config["num_hidden_layers"] + super().__init__(layer_num, network_config) + self.head_dim_ = network_config["head_dim"] + self.mrope_section = torch.tensor( + network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" + ) + return diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_conv_net.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_conv_net.py new file mode 100644 index 0000000000..1caabf31da --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_conv_net.py @@ -0,0 +1,116 @@ +import math +import torch +import numpy as np +from typing import Dict, Optional +from transformers.activations import ACT2FN +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp +from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel +from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp + + +class Qwen3OmniMoeCausalConvNetWeight(BaseWeightTpl, PlatformAwareOp): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + weight_name: str, + bias_name: str, + data_type: torch.dtype, + dilation: int = 1, + stride: int = 1, + groups: int = 1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.dilation = dilation + self.stride = stride + self.groups = groups + self.weight_name: str = weight_name + self.bias_name: str = bias_name + self.data_type_ = data_type + + self.kernel_size_effective = (kernel_size - 1) * dilation + 1 + self.padding = self.kernel_size_effective - self.stride + + self._create_weight() + + def _create_weight(self): + # Conv1d weight shape: (out_channels, in_channels // groups, kernel_size) + self.weight: torch.Tensor = torch.empty( + self.out_channels, + self.in_channels // self.groups, + self.kernel_size, + dtype=self.data_type_, + device=self.device_id_, + ) + self.bias: torch.Tensor = torch.empty(self.out_channels, dtype=self.data_type_, device=self.device_id_) + self.weight.load_ok = False + self.bias.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights: + t_weight = weights[self.weight_name] + assert t_weight.shape == ( + self.out_channels, + self.in_channels // self.groups, + self.kernel_size, + ) + self.weight.copy_(t_weight.to(self.data_type_)) + self.weight.load_ok = True + if self.bias_name in weights: + t_bias = weights[self.bias_name] + assert t_bias.shape == (self.out_channels,) + self.bias.copy_(t_bias.to(self.data_type_)) + self.bias.load_ok = True + + def verify_load(self): + return self.weight.load_ok and self.bias.load_ok + + def _get_extra_padding_for_conv1d(self, hidden_state: torch.Tensor) -> int: + length = hidden_state.shape[-1] + n_frames = (length - self.kernel_size_effective + self.padding) / self.stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * self.stride + (self.kernel_size_effective - self.padding) + return int(ideal_length - length) + + def _native_forward( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + extra_padding = self._get_extra_padding_for_conv1d(hidden_state) + hidden_state = torch.nn.functional.pad(hidden_state, (self.padding, extra_padding), 
mode="constant", value=0) + x = torch.nn.functional.conv1d( + hidden_state, + self.weight, + self.bias, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=self.groups, + ).contiguous() + if out is not None: + out.copy_(x) + return out + return x + + def _cuda_forward( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + if out is None: + result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) + return result + result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) + out.copy_(result) + return out + + def _musa_forward( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._cuda_forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) + + def __call__( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_trans_conv_net.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_trans_conv_net.py new file mode 100644 index 0000000000..058565e848 --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_trans_conv_net.py @@ -0,0 +1,101 @@ +import math +import torch +import numpy as np +from typing import Dict, Optional +from transformers.activations import ACT2FN +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp +from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel +from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp + + +class Qwen3OmniMoeCode2wavCausalTransConvNetWeight(BaseWeightTpl, PlatformAwareOp): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + weight_name: str, + bias_name: str, + data_type: torch.dtype, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.weight_name: str = weight_name + self.bias_name: str = bias_name + self.data_type_ = data_type + + pad = kernel_size - stride + self.left_pad = math.ceil(pad) + self.right_pad = pad = self.left_pad + + self._create_weight() + + def _create_weight(self): + # ConvTranspose1d weight shape: (in_channels, out_channels, kernel_size) when groups=1 + self.weight: torch.Tensor = torch.empty( + self.in_channels, self.out_channels, self.kernel_size, dtype=self.data_type_, device=self.device_id_ + ) + self.bias: torch.Tensor = torch.empty(self.out_channels, dtype=self.data_type_, device=self.device_id_) + self.weight.load_ok = False + self.bias.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights: + t_weight = weights[self.weight_name] + assert t_weight.shape == (self.in_channels, self.out_channels, self.kernel_size) + self.weight.copy_(t_weight.to(self.data_type_)) + self.weight.load_ok = True + if self.bias_name in weights: + t_bias = weights[self.bias_name] + assert t_bias.shape == (self.out_channels,) + self.bias.copy_(t_bias.to(self.data_type_)) + self.bias.load_ok = True + + def 
verify_load(self): + return self.weight.load_ok and self.bias.load_ok + + def _native_forward( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + # hidden_state: [B, C_in, L] + x = torch.nn.functional.conv_transpose1d( + hidden_state, + self.weight, + self.bias, + stride=self.stride, + padding=0, + output_padding=0, + groups=1, + dilation=1, + ) + x = x[..., self.left_pad : x.shape[-1] - self.right_pad].contiguous() + if out is not None: + out.copy_(x) + return out + return x + + def _cuda_forward( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + if out is None: + # output length depends on input length; allocate after computing + result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) + return result + result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) + out.copy_(result) + return out + + def _musa_forward( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._cuda_forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) + + def __call__( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_conv_ne_xt.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_conv_ne_xt.py new file mode 100644 index 0000000000..21b0462dce --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_conv_ne_xt.py @@ -0,0 +1,165 @@ +import math +import torch +import numpy as np +from typing import Dict, Optional +from transformers.activations import ACT2FN +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp +from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel +from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp + + +class Qwen3OmniMoeConvNeXtBlockWeight(BaseWeightTpl, PlatformAwareOp): + def __init__( + self, + dim: int, + dwconv, + norm_weight_name: str, + norm_bias_name: str, + pwconv1_weight_name: str, + pwconv1_bias_name: str, + pwconv2_weight_name: str, + pwconv2_bias_name: str, + gamma_name: str, + data_type: torch.dtype, + eps: float = 1e-6, + ): + super().__init__() + self.dim = dim + self.dwconv = dwconv + self.norm_weight_name = norm_weight_name + self.norm_bias_name = norm_bias_name + self.pwconv1_weight_name = pwconv1_weight_name + self.pwconv1_bias_name = pwconv1_bias_name + self.pwconv2_weight_name = pwconv2_weight_name + self.pwconv2_bias_name = pwconv2_bias_name + self.gamma_name = gamma_name + self.data_type_ = data_type + self.eps = eps + + self._create_weight() + + def _create_weight(self): + self.norm_weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + self.norm_bias: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + + self.pwconv1_weight: torch.Tensor = torch.empty( + 4 * self.dim, self.dim, dtype=self.data_type_, device=self.device_id_ + ) + self.pwconv1_bias: torch.Tensor = torch.empty(4 * self.dim, dtype=self.data_type_, 
device=self.device_id_) + + self.pwconv2_weight: torch.Tensor = torch.empty( + self.dim, 4 * self.dim, dtype=self.data_type_, device=self.device_id_ + ) + self.pwconv2_bias: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + + self.gamma: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + + self.norm_weight.load_ok = False + self.norm_bias.load_ok = False + self.pwconv1_weight.load_ok = False + self.pwconv1_bias.load_ok = False + self.pwconv2_weight.load_ok = False + self.pwconv2_bias.load_ok = False + self.gamma.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.norm_weight_name in weights: + t = weights[self.norm_weight_name] + assert t.shape == (self.dim,) + self.norm_weight.copy_(t.to(self.data_type_)) + self.norm_weight.load_ok = True + + if self.norm_bias_name in weights: + t = weights[self.norm_bias_name] + assert t.shape == (self.dim,) + self.norm_bias.copy_(t.to(self.data_type_)) + self.norm_bias.load_ok = True + + if self.pwconv1_weight_name in weights: + t = weights[self.pwconv1_weight_name] + assert t.shape == (4 * self.dim, self.dim) + self.pwconv1_weight.copy_(t.to(self.data_type_)) + self.pwconv1_weight.load_ok = True + + if self.pwconv1_bias_name in weights: + t = weights[self.pwconv1_bias_name] + assert t.shape == (4 * self.dim,) + self.pwconv1_bias.copy_(t.to(self.data_type_)) + self.pwconv1_bias.load_ok = True + + if self.pwconv2_weight_name in weights: + t = weights[self.pwconv2_weight_name] + assert t.shape == (self.dim, 4 * self.dim) + self.pwconv2_weight.copy_(t.to(self.data_type_)) + self.pwconv2_weight.load_ok = True + + if self.pwconv2_bias_name in weights: + t = weights[self.pwconv2_bias_name] + assert t.shape == (self.dim,) + self.pwconv2_bias.copy_(t.to(self.data_type_)) + self.pwconv2_bias.load_ok = True + + if self.gamma_name in weights: + t = weights[self.gamma_name] + assert t.shape == (self.dim,) + self.gamma.copy_(t.to(self.data_type_)) + self.gamma.load_ok = True + + def verify_load(self): + return ( + self.norm_weight.load_ok + and self.norm_bias.load_ok + and self.pwconv1_weight.load_ok + and self.pwconv1_bias.load_ok + and self.pwconv2_weight.load_ok + and self.pwconv2_bias.load_ok + and self.gamma.load_ok + ) + + def _native_forward( + self, hidden_states: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + input = hidden_states + + hidden_states = self.dwconv(hidden_states) # [B, C, L] + hidden_states = hidden_states.permute(0, 2, 1) # [B, L, C] + + mean = hidden_states.mean(dim=-1, keepdim=True) + var = (hidden_states - mean).pow(2).mean(dim=-1, keepdim=True) + hidden_states = (hidden_states - mean) / torch.sqrt(var + self.eps) + hidden_states = hidden_states * self.norm_weight.view(1, 1, -1) + self.norm_bias.view(1, 1, -1) + + hidden_states = torch.nn.functional.linear(hidden_states, self.pwconv1_weight, self.pwconv1_bias) + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = torch.nn.functional.linear(hidden_states, self.pwconv2_weight, self.pwconv2_bias) + + hidden_states = hidden_states * self.gamma.view(1, 1, -1) + + hidden_states = hidden_states.permute(0, 2, 1) # [B, C, L] + hidden_states = input + hidden_states + + if out is not None: + out.copy_(hidden_states) + return out + return hidden_states + + def _cuda_forward( + self, hidden_states: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + if out is None: + result = 
self._native_forward(hidden_states=hidden_states, out=None, _alloc_func=alloc_func) + return result + result = self._native_forward(hidden_states=hidden_states, out=None, _alloc_func=alloc_func) + out.copy_(result) + return out + + def _musa_forward( + self, hidden_states: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._cuda_forward(hidden_states=hidden_states, out=out, alloc_func=alloc_func) + + def __call__( + self, hidden_states: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(hidden_states=hidden_states, out=out, alloc_func=alloc_func) diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/talker_resize_mlp_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/talker_resize_mlp_weight.py new file mode 100644 index 0000000000..fac9c5cdc7 --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/talker_resize_mlp_weight.py @@ -0,0 +1,109 @@ +import torch +import numpy as np +from typing import Dict, Optional +from transformers.activations import ACT2FN +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp +from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel +from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp + + +class Qwen3OmniMoeTalkerResizeMLPWeight(BaseWeightTpl, PlatformAwareOp): + def __init__( + self, + in_dim: int, + intermediate_dim: int, + out_dim: int, + fc1_weight_name: str, + fc1_bias_name: str, + fc2_weight_name: str, + fc2_bias_name: str, + hidden_act: str, + data_type: torch.dtype, + ): + super().__init__() + self.in_dim = in_dim + self.intermediate_dim = intermediate_dim + self.out_dim = out_dim + self.fc1_weight_name: str = fc1_weight_name + self.fc1_bias_name: str = fc1_bias_name + self.fc2_weight_name: str = fc2_weight_name + self.fc2_bias_name: str = fc2_bias_name + self.data_type_ = data_type + self.act_fn = ACT2FN[hidden_act] + self._create_weight() + + def _create_weight(self): + self.fc1_weight: torch.Tensor = torch.empty( + self.intermediate_dim, self.in_dim, dtype=self.data_type_, device=self.device_id_ + ) + self.fc1_bias: torch.Tensor = torch.empty(self.intermediate_dim, dtype=self.data_type_, device=self.device_id_) + self.fc2_weight: torch.Tensor = torch.empty( + self.out_dim, self.intermediate_dim, dtype=self.data_type_, device=self.device_id_ + ) + self.fc2_bias: torch.Tensor = torch.empty(self.out_dim, dtype=self.data_type_, device=self.device_id_) + self.fc1_weight.load_ok = False + self.fc1_bias.load_ok = False + self.fc2_weight.load_ok = False + self.fc2_bias.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.fc1_weight_name in weights: + t = weights[self.fc1_weight_name] + assert t.shape == (self.intermediate_dim, self.in_dim) + self.fc1_weight.copy_(t.to(self.data_type_)) + self.fc1_weight.load_ok = True + if self.fc1_bias_name in weights: + t = weights[self.fc1_bias_name] + assert t.shape == (self.intermediate_dim,) + self.fc1_bias.copy_(t.to(self.data_type_)) + self.fc1_bias.load_ok = True + if self.fc2_weight_name in weights: + t = weights[self.fc2_weight_name] + assert t.shape == (self.out_dim, self.intermediate_dim) + self.fc2_weight.copy_(t.to(self.data_type_)) + self.fc2_weight.load_ok = True + if 
self.fc2_bias_name in weights: + t = weights[self.fc2_bias_name] + assert t.shape == (self.out_dim,) + self.fc2_bias.copy_(t.to(self.data_type_)) + self.fc2_bias.load_ok = True + + def verify_load(self): + return self.fc1_weight.load_ok and self.fc1_bias.load_ok and self.fc2_weight.load_ok and self.fc2_bias.load_ok + + def _native_forward( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + in_dim = hidden_state.shape[-1] + assert in_dim == self.in_dim + x = hidden_state.reshape(-1, in_dim) + y = torch.nn.functional.linear(x, self.fc1_weight, self.fc1_bias) + y = self.act_fn(y) + y = torch.nn.functional.linear(y, self.fc2_weight, self.fc2_bias) + y = y.reshape(*hidden_state.shape[:-1], self.out_dim) + if out is not None: + out.copy_(y) + return out + return y + + def _cuda_forward( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + if out is None: + out = alloc_func( + (*hidden_state.shape[:-1], self.out_dim), dtype=hidden_state.dtype, device=hidden_state.device + ) + result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) + out.copy_(result) + return out + + def _musa_forward( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._cuda_forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) + + def __call__( + self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..614a68bcb2 --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "thinker.model." 
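+    # Illustrative example of the rename performed below (hypothetical checkpoint key, shown for clarity only):
+    #   "thinker.model.embed_tokens.weight" -> "model.embed_tokens.weight"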
+ keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "model.")] = weights.pop(k) + + +class Qwen3OmniMOEThinkerPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + self.layer_num_ = network_config["num_hidden_layers"] + super().__init__(data_type, network_config) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py new file mode 100644 index 0000000000..220d846e0e --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py @@ -0,0 +1,10 @@ +import os +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight + + +class Qwen3OmniMOEThinkerTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + self.layer_num_ = network_config["num_hidden_layers"] + super().__init__(layer_num, data_type, network_config, quant_cfg) + return diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py new file mode 100644 index 0000000000..4ceab9573c --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/model.py @@ -0,0 +1,124 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer + +from lightllm.models.qwen3_omni_moe_thinker.layer_infer.transformer_layer_infer import Qwen3OmniMOETransformerLayerInfer +from lightllm.models.qwen3_omni_moe_thinker.layer_weights.pre_and_post_layer_weight import ( + Qwen3OmniMOEThinkerPreAndPostLayerWeight, +) +from lightllm.models.qwen3_omni_moe_thinker.layer_weights.transformers_layer_weight import ( + Qwen3OmniMOEThinkerTransformerLayerWeight, +) + +from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel +from lightllm.models.qwen3_omni_moe_thinker.infer_struct import Qwen3OmniMOEInferStateInfo +from lightllm.models.qwen3_vl.model import QWen3VLTokenizer +from lightllm.server.core.objs import SamplingParams +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem + + +# <|audio_start|><|audio_pad|><|audio_end|> +AUDIO_START_TOKEN = "<|audio_start|>" +AUDIO_END_TOKEN = "<|audio_end|>" + +MIN_AUDIO_LEN = 480 + + +class QWen3OmniTokenizer(QWen3VLTokenizer): + def __init__(self, tokenizer=None, image_processor=None, **kwargs): + self.tokenizer = tokenizer + self.image_processor = image_processor + self.min_pixel = self.image_processor.min_pixels + self.max_pixel = self.image_processor.max_pixels + self.patch_size = self.image_processor.patch_size + self.merge_size = self.image_processor.merge_size + self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"] + self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"] + self.image_token_id = kwargs["model_cfg"]["image_token_id"] + + self.audio_start_tag = AUDIO_START_TOKEN + self.audio_start_id = tokenizer.convert_tokens_to_ids(self.audio_start_tag) + + self.audio_end_tag = AUDIO_END_TOKEN + 
self.audio_end_id = tokenizer.convert_tokens_to_ids(self.audio_end_tag)
+
+        self.audio_min_length = MIN_AUDIO_LEN
+        self.audio_max_length = 16000 * 30
+
+    def init_audioitem_extral_params(
+        self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        return
+
+    def get_image_token_length(self, img: ImageItem):
+        return (
+            self.get_image_patch_func(
+                img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
+            )
+            * self.image_length
+        )
+
+    def get_audio_token_length(self, audio: AudioItem):
+        L = audio.audio_length
+        audio_token_num = 0
+        chunk_lens = []
+        if L <= self.audio_max_length:
+            cur_len = L
+            if cur_len < self.audio_min_length:
+                cur_len = self.audio_min_length
+            chunk_lens.append(cur_len)
+        else:
+            start = 0
+            while start < L:
+                end = min(start + self.audio_max_length, L)
+                cur_len = end - start
+
+                if cur_len < self.audio_min_length:
+                    cur_len = self.audio_min_length
+
+                chunk_lens.append(cur_len)
+                start = end
+        for chunk_len in chunk_lens:
+            mel_len = chunk_len // 160
+            dilation = 1
+            L_in = mel_len
+            for (padding, kernel_size, stride) in [(1, 3, 1), (1, 3, 2)]:
+                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
+                L_out = 1 + L_out // stride
+                L_in = L_out
+            audio_len_after_cnn = L_out
+            chunk_token_num = (audio_len_after_cnn - 2) // 2 + 1
+            audio_token_num += int(chunk_token_num)
+        return audio_token_num
+
+
+@ModelRegistry(["qwen3_omni_moe"], is_multimodal=True)
+class Qwen3OmniMOETpPartModel(Qwen3VLMOETpPartModel):
+
+    pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer
+    transformer_layer_infer_class = Qwen3OmniMOETransformerLayerInfer
+
+    pre_and_post_weight_class = Qwen3OmniMOEThinkerPreAndPostLayerWeight
+    transformer_weight_class = Qwen3OmniMOEThinkerTransformerLayerWeight
+
+    infer_state_class = Qwen3OmniMOEInferStateInfo
+
+    def __init__(self, kvargs):
+        super().__init__(kvargs)
+        return
+
+    def _init_config(self):
+        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+            all_config = json.load(json_file)
+        self.config = all_config["thinker_config"]["text_config"]
+        # rename keys
+        print(f"self.config is {self.config}")
+        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+        if self.finetune_config:
+            self.config["vocab_size"] = self.finetune_config.vocab_size
+        return
diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
new file mode 100644
index 0000000000..7e62c14ecd
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
@@ -0,0 +1,413 @@
+import os
+import json
+import math
+import torch
+import librosa
+import numpy as np
+from io import BytesIO
+from torch import Tensor, nn
+from safetensors import safe_open
+from torch.nn import functional as F
+from typing import Callable, Optional, Union, List
+from rpyc.utils.classic import obtain
+
+from transformers.activations import ACT2FN
+
+from lightllm.server.multimodal_params import AudioItem
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from 
lightllm.models.qwen3_omni_moe_thinker.audio_process import WhisperFeatureExtractor + + +def _get_feat_extract_output_lengths(input_lengths): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths + + +class Qwen3OmniMoeAudioEncoderLayer(nn.Module): + def __init__( + self, + d_model, + encoder_attention_heads, + attention_dropout, + dropout, + activation_function, + activation_dropout, + encoder_ffn_dim, + ): + super().__init__() + self.embed_dim = d_model + self.self_attn = Qwen3OmniMoeAudioAttention(d_model, encoder_attention_heads, attention_dropout) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = dropout + self.activation_fn = ACT2FN[activation_function] + self.activation_dropout = activation_dropout + self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim) + self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + **kwargs, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return outputs + + +class Qwen3OmniMoeAudioAttention(nn.Module): + def __init__(self, d_model, encoder_attention_heads, attention_dropout): + super().__init__() + self.embed_dim = d_model + self.num_heads = encoder_attention_heads + self.dropout = attention_dropout + self.head_dim = self.embed_dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.scaling = self.head_dim ** -0.5 + self.attention_dropout = 0.0 + self.is_decoder = False + self.is_causal = False + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: int = 0, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + q = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) + + flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.out_proj(attn_output) + return attn_output + + +class SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels input") + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + self.positional_embedding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +class Qwen3OmniMoeAudioEncoder(nn.Module): + def __init__( + self, + kvargs, + dropout=0, + d_model=1280, + num_mel_bins=128, + max_source_positions=1500, + scale_embedding=False, + n_window=50, + encoder_layers=32, + downsample_hidden_size=480, + activation_function="gelu", + output_dim=2048, + n_window_infer=800, + conv_chunksize=500, + encoder_attention_heads=20, + attention_dropout=0, + activation_dropout=0, + encoder_ffn_dim=5120, + ): + super().__init__() + self.data_type = kvargs.get("data_type", "bfloat16") + self.dropout = dropout + self.embed_dim = d_model + self.num_mel_bins = num_mel_bins + self.max_source_positions = max_source_positions + self.embed_scale = math.sqrt(self.embed_dim) if scale_embedding else 1.0 + self.n_window = n_window + self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, self.embed_dim) + self.layers = nn.ModuleList( + [ + Qwen3OmniMoeAudioEncoderLayer( + d_model, + encoder_attention_heads, + attention_dropout, + dropout, + activation_function, + activation_dropout, + encoder_ffn_dim, + ) + for _ in range(encoder_layers) + ] + ) + self.ln_post = nn.LayerNorm(d_model) + self.gradient_checkpointing = False + self.conv2d1 = nn.Conv2d(1, downsample_hidden_size, 3, 2, padding=1) + self.conv2d2 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1) + self.conv2d3 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1) + self.conv_out = nn.Linear( + downsample_hidden_size * ((((num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2), + d_model, + bias=False, + ) + self.proj1 = nn.Linear(d_model, d_model) + self.act = ACT2FN[activation_function] + self.proj2 = 
nn.Linear(d_model, output_dim) + self.n_window_infer = n_window_infer + self.conv_chunksize = conv_chunksize + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + def load_model(self, weight_dir, config): + processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + self.processor = WhisperFeatureExtractor(**processor_config_dict) + + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "thinker.audio_tower" in k: + weight_dict[k[len("thinker.audio_tower.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "thinker.audio_tower" in k: + weight_dict[k[len("thinker.audio_tower.") :]] = f.get_tensor(k) + + self.load_state_dict(weight_dict) + + def forward( + self, + input_features, + feature_lens=None, + aftercnn_lens=None, + ): + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths[chunk_lengths == 0] = self.n_window * 2 + + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + padded_mask_after_cnn = nn.utils.rnn.pad_sequence( + [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], + batch_first=True, + ) + padded_feature = padded_feature.unsqueeze(1) + # Split to chunk to avoid OOM during convolution + padded_embeds = [] + for chunk in padded_feature.split(self.conv_chunksize, dim=0): + padded_embed = F.gelu(self.conv2d1(chunk)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) + padded_embeds.append(padded_embed) + padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() + padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 
2).contiguous().view(b, t, c * f))
+
+        positional_embedding = (
+            self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
+            .unsqueeze(0)
+            .to(padded_embed.dtype)
+            .to(padded_embed.device)
+        )
+        padded_embed = padded_embed + positional_embedding
+        hidden_states = padded_embed[padded_mask_after_cnn]
+        cu_chunk_lens = [0]
+        window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
+        for cnn_len in aftercnn_lens:
+            cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
+            remainder = cnn_len % window_aftercnn
+            if remainder != 0:
+                cu_chunk_lens += [remainder]
+        cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        for encoder_layer in self.layers:
+            layer_outputs = encoder_layer(
+                hidden_states,
+                cu_seqlens,
+                max_seqlen,
+            )
+            hidden_states = layer_outputs[0]
+
+        hidden_states = self.ln_post(hidden_states)
+        hidden_states = self.proj1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.proj2(hidden_states)
+        return hidden_states
+
+    def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
+        # each element is one audio chunk
+        batch_audios = []
+        batch_audio_lens = []
+        uuids = []
+        items: List[AudioItem] = []
+        # records which audio_items index each chunk belongs to
+        chunk_owner_index = []
+        for i, item in enumerate(audio_items):
+            if isinstance(item, AudioItem):
+                uuids.append(item.uuid)
+                items.append(item)
+                audio_data = read_shm(get_shm_name_data(item.uuid))
+                audio = BytesIO(audio_data)
+                audio, _ = librosa.load(audio, sr=16000)
+            else:
+                raise ValueError(f"cannot read audio which type is {type(item)}!")
+
+            # padding to min audio len
+            MIN_AUDIO_LEN = 480
+
+            if audio.shape[0] < MIN_AUDIO_LEN:
+                audio = np.pad(audio, (0, MIN_AUDIO_LEN - len(audio)), mode="constant", constant_values=0.0)
+
+            if audio.shape[0] > self.max_length:
+                start = 0
+                while start < audio.shape[0]:
+                    end = min(start + self.max_length, audio.shape[0])
+                    chunk = audio[start:end]
+
+                    if chunk.shape[0] < MIN_AUDIO_LEN:
+                        chunk = np.pad(chunk, (0, MIN_AUDIO_LEN - chunk.shape[0]), mode="constant", constant_values=0.0)
+                    batch_audios.append(chunk)
+                    batch_audio_lens.append(min(chunk.shape[0], self.max_length))
+                    chunk_owner_index.append(i)
+
+                    start = end
+            else:
+                batch_audio_lens.append(min(audio.shape[0], self.max_length))
+                batch_audios.append(audio)
+                chunk_owner_index.append(i)
+
+        batch_audio_lens = np.array(batch_audio_lens, dtype=np.int32)
+
+        audios, audio_lens_after_cnn = self.processor._preprocess(
+            batch_audios, sampling_rate=16000, return_tensors="pt"
+        )
+        audios = self.forward(audios, audio_lens_after_cnn)
+        audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
+        audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1
+
+        num_audios = len(audio_items)
+        per_audio_embeds = [[] for _ in range(num_audios)]
+
+        for chunk_idx, owner in enumerate(chunk_owner_index):
+            token_len = int(audio_token_num[chunk_idx])
+            if token_len <= 0:
+                continue
+            per_audio_embeds[owner].append(audios[chunk_idx][:token_len])
+
+        ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
+        ids_to_set = []
+        for i, ready in enumerate(ready_audio):
+            if ready:
+                continue
+
+            uid = uuids[i]
+            item = items[i]
+
+            # concatenate all chunk embeddings belonging to this audio
+            cur_embed = torch.cat(per_audio_embeds[i], dim=0)
+            cpu_embed_cache_client.copy_to_cache(
+                embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache
+ ) + ids_to_set.append(uid) + + if ids_to_set: + self.cache_client.root.set_items_embed(ids=ids_to_set) + torch.cuda.current_stream().synchronize() diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py new file mode 100644 index 0000000000..dd9b54ee81 --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -0,0 +1,408 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import time +from PIL import Image +from io import BytesIO +from typing import List +from safetensors import safe_open +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.activations import ACT2FN + +from lightllm.server.multimodal_params import ImageItem +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention + + +class Qwen3OmniMoeVisionMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3OmniMoeVisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + if hidden_states.dtype != target_dtype: + hidden_states = hidden_states.to(target_dtype) + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3OmniMoeVisionPatchMerger(nn.Module): + def __init__(self, hidden_size, out_hidden_size, spatial_merge_size, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = hidden_size * (spatial_merge_size ** 2) + self.use_postshuffle_norm = use_postshuffle_norm + self.ln_q = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else hidden_size, eps=1e-6) + self.mlp = nn.ModuleList( 
+ [ + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, out_hidden_size), + ] + ) + + def forward(self, hidden: torch.Tensor) -> torch.Tensor: + hidden = self.ln_q(hidden.view(-1, self.hidden_size) if self.use_postshuffle_norm else hidden).view( + -1, self.hidden_size + ) + for layer in self.mlp: + hidden = layer(hidden) + return hidden + + +class Qwen3OmniMoeVisionBlock(nn.Module): + def __init__(self, hidden_size, intermediate_size, num_heads, hidden_act) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6) + + self.attn = VisionFlashAttention(hidden_size, num_heads=num_heads) + self.mlp = Qwen3OmniMoeVisionMLP(hidden_size, intermediate_size, hidden_act) + + def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3OmniMoeVisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + depth=27, + out_hidden_size=4096, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + deepstack_visual_indexes=[8, 16, 24], + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + num_position_embeddings=2304, + **kwargs, + ): + super().__init__() + self.data_type = kvargs.get("data_type", "bfloat16") + + self.depth = depth + self.out_hidden_size = out_hidden_size + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.intermediate_size = intermediate_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.num_position_embeddings = num_position_embeddings + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.deepstack_visual_indexes = deepstack_visual_indexes + self.merger_list = nn.ModuleList( + [ + Qwen3OmniMoeVisionPatchMerger( + hidden_size=self.hidden_size, + out_hidden_size=self.out_hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + ) + for _ in range(len(self.deepstack_visual_indexes)) + ] + ) + + self.patch_embed = Qwen3OmniMoeVisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=self.in_channels, + embed_dim=self.hidden_size, + ) + + self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) + self.num_grid_per_side = int(self.num_position_embeddings ** 0.5) + + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).cuda() + + self.blocks = nn.ModuleList( + [ + Qwen3OmniMoeVisionBlock(self.hidden_size, self.intermediate_size, self.num_heads, self.hidden_act) + for _ in range(self.depth) + ] + ) + self.merger = Qwen3OmniMoeVisionPatchMerger( + hidden_size=self.hidden_size, + out_hidden_size=self.out_hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=False, + ) + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: 
+ self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids): + all_chunks = [] + + for start, end in valid_ids: + hs_i = image_embed[start:end] + ds_i_list = [feat[start:end] for feat in deepstack_feature_lists] + combined_i = torch.cat([hs_i, *ds_i_list], dim=1).view((end - start), len(ds_i_list) + 1, hs_i.shape[-1]) + all_chunks.append(combined_i) + + all_img_embeds_ds = torch.cat(all_chunks, dim=0) + return all_img_embeds_ds, valid_ids + + def load_model(self, weight_dir): + + processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + self.processor = Qwen2VLImageProcessor(**processor_config_dict) + + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "thinker.visual" in k: + weight_dict[k[len("thinker.visual.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "thinker.visual" in k: + weight_dict[k[len("thinker.visual.") :]] = f.get_tensor(k) + + self.load_state_dict(weight_dict) + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h) # block row indices + block_cols = torch.arange(merged_w) # block col indices + intra_row = torch.arange(merge_size) # intra-block row offsets + intra_col = torch.arange(merge_size) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + max_hw = int(grid_thw[:, 1:].max().item()) + cos_full, sin_full = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + cos = cos_full[pos_ids].flatten(1) + sin = sin_full[pos_ids].flatten(1) + return cos, sin + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 
1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw) + rotary_cos = rotary_cos.to("cuda", non_blocking=True) + rotary_sin = rotary_sin.to("cuda", non_blocking=True) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + dtype=torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.merger_list[self.deepstack_visual_indexes.index(layer_num)](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + print(f"hidden_states is {hidden_states}, deepstack is {deepstack_feature_lists}") + return hidden_states, deepstack_feature_lists + + def encode(self, images: List[ImageItem]): + img_tensors = [] + valid_ids = [] + valid_id = 0 + img_grids = [] + uuids = [] + for i, img in enumerate(images): + if isinstance(img, ImageItem): + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + img_tensors.append(pixel_values) + img_grids.append(image_grid_thw) + else: + raise Exception("Unsupport input types: {} for 
{}".format(type(img), img)) + + # must devide merge_length + cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2) + + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + grid_thw = torch.cat(img_grids, dim=0) + + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_thw = grid_thw.to("cuda", non_blocking=True) + img_embeds, deepstack_feature_lists = self.forward(pixel_values, grid_thw=image_grid_thw) + all_img_embeds_df, valid_ids = self.concat_img_embed_and_deepstack_features( + img_embeds, deepstack_feature_lists, valid_ids + ) + + return all_img_embeds_df, uuids, valid_ids diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py index f3c813800c..8c123bbd1c 100644 --- a/lightllm/server/audioserver/model_infer/model_rpc.py +++ b/lightllm/server/audioserver/model_infer/model_rpc.py @@ -4,6 +4,7 @@ from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig from lightllm.models.whisper.whisper_audio import WhisperAudioModel +from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_audio import Qwen3OmniMoeAudioEncoder from lightllm.server.multimodal_params import AudioItem from lightllm.utils.infer_utils import set_random_seed from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient @@ -19,6 +20,9 @@ def exposed_init_model(self, kvargs): weight_dir = kvargs["weight_dir"] model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) + if model_cfg.get("thinker_config") is not None: + model_cfg = model_cfg["thinker_config"] + audio_config = model_cfg["audio_config"] model_kvargs = {"cache_port": kvargs["cache_port"], "data_type": kvargs["data_type"]} @@ -26,6 +30,8 @@ def exposed_init_model(self, kvargs): self.model_type = audio_config["model_type"] if self.model_type == "clap_audio_model" or self.model_type == "whisper": self.model = WhisperAudioModel(model_kvargs) + elif self.model_type == "qwen3_omni_moe_audio_encoder": + self.model = Qwen3OmniMoeAudioEncoder(model_kvargs) else: raise Exception(f"can not support {self.model_type} now") diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e0b2bd425e..c245fa85a6 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -31,6 +31,7 @@ from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer from ..models.gemma3.model import Gemma3Tokenizer +from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. 
 _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
@@ -100,6 +101,12 @@ def get_tokenizer(
         tokenizer = QWen3VLTokenizer(
             tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg
         )
+    elif model_cfg.get("thinker_config") is not None:
+        from transformers import AutoProcessor
+
+        model_cfg = model_cfg["thinker_config"]
+        processor = AutoProcessor.from_pretrained(tokenizer_name)
+        tokenizer = QWen3OmniTokenizer(tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg)
     elif model_type == "internvl_chat":
         tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
     elif model_type == "gemma3":
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index 8f4a1ee450..e034dc662d 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -19,6 +19,7 @@
 from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel
 from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel
 from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
+from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel
 from lightllm.utils.infer_utils import set_random_seed
 from lightllm.utils.dist_utils import init_vision_distributed_env
 from lightllm.utils.graceful_utils import graceful_registry
@@ -80,6 +81,15 @@ def exposed_init_model(self, kvargs):
             # self.model = InternVLVisionModel()
         elif self.model_type == "gemma3":
             self.model = Gemma3VisionModel()
+        elif (
+            model_cfg.get("thinker_config", {}).get("vision_config", {}).get("model_type")
+            == "qwen3_omni_moe_vision_encoder"
+        ):
+            self.model = (
+                Qwen3OmniMoeVisionTransformerPretrainedModel(kvargs, **model_cfg["thinker_config"]["vision_config"])
+                .eval()
+                .bfloat16()
+            )
         else:
             raise Exception(f"can not support {self.model_type} now")
diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py
index b06309f96f..9085be4572 100644
--- a/lightllm/utils/config_utils.py
+++ b/lightllm/utils/config_utils.py
@@ -25,6 +25,8 @@ def _get_config_llm_keyvalue(model_path: str, key_name: list[str]):
             value = config_json["llm_config"][key]
         except:
             value = config_json.get("text_config", {}).get(key)
+            if config_json.get("thinker_config") is not None:
+                value = config_json.get("thinker_config", {}).get("text_config").get(key)
         if value is not None:
             return value
@@ -78,6 +80,7 @@ def get_layer_num(model_path: str) -> int:

 def get_eos_token_ids(model_path: str) -> Optional[List[int]]:
     eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"])
+    return [151645]  # revisit later: maybe just change config.json directly?
     if isinstance(eos_token_id, int):
         return [eos_token_id]
     if isinstance(eos_token_id, list):
@@ -100,6 +103,8 @@ def get_model_architectures(model_path: str):

 def get_vocab_size(model_path: str):
     try:
         config_json = get_config_json(model_path)
+        if "thinker_config" in config_json:
+            config_json = config_json["thinker_config"]
         if "llm_config" in config_json:
             vocab_size = int(config_json["llm_config"]["vocab_size"])
             return vocab_size

From d8cc7b2564af7e0cb38a4a8c2d44825637bb78b6 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 5 Feb 2026 06:41:52 +0000
Subject: [PATCH 02/20] fix

---
 lightllm/utils/config_utils.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py
index 9085be4572..790f185f25 100644
--- a/lightllm/utils/config_utils.py
+++ b/lightllm/utils/config_utils.py
@@ -79,8 +79,15 @@ def get_layer_num(model_path: str) -> int:

 def get_eos_token_ids(model_path: str) -> Optional[List[int]]:
+    try:
+        # qwen3-omni uses a special eos_token_id
+        config_json = get_config_json(model_path)
+        assert config_json["architectures"][0] == "Qwen3OmniMoeForConditionalGeneration"
+        return [151645]
+    except:
+        pass
+
     eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"])
-    return [151645]  # revisit later: maybe just change config.json directly?
     if isinstance(eos_token_id, int):
         return [eos_token_id]
     if isinstance(eos_token_id, list):
@@ -103,6 +110,7 @@ def get_model_architectures(model_path: str):
 def get_vocab_size(model_path: str):
     try:
         config_json = get_config_json(model_path)
+        # qwen3-omni special case
         if "thinker_config" in config_json:
             config_json = config_json["thinker_config"]
         if "llm_config" in config_json:

From f50c20ad08ddf990e0d4c06840288d1dc5b08de8 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Thu, 5 Feb 2026 07:12:22 +0000
Subject: [PATCH 03/20] fix qwen3-omni tokenizer

---
 lightllm/models/qwen3_omni_moe_thinker/model.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py
index 4ceab9573c..c8de223e65 100644
--- a/lightllm/models/qwen3_omni_moe_thinker/model.py
+++ b/lightllm/models/qwen3_omni_moe_thinker/model.py
@@ -53,14 +53,6 @@ def init_audioitem_extral_params(
     ):
         return

-    def get_image_token_length(self, img: ImageItem):
-        return (
-            self.get_image_patch_func(
-                img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
-            )
-            * self.image_length
-        )
-
     def get_audio_token_length(self, audio: AudioItem):
         L = audio.audio_length
         audio_token_num = 0

From 2cc5d01c88b03c00836ab1b63db0b7e7117d2d68 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 5 Feb 2026 07:32:51 +0000
Subject: [PATCH 04/20] fix prelayer.
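
The qwen3-omni checkpoint stores its language model under a "thinker." prefix, so this patch reverts the shared Qwen2 pre/post layer weight to its upstream form and declares the omni embedding, lm_head and final-norm weights directly with their full checkpoint names instead of renaming keys at load time. A minimal sketch of the naming convention this switches to (names copied from the diff below; the helper itself is illustrative only, not part of the patch):

    # Illustrative only: checkpoint names referenced by the thinker pre/post weights.
    THINKER_PRE_POST_WEIGHTS = {
        "embed_tokens": "thinker.model.embed_tokens.weight",
        "lm_head": "thinker.lm_head.weight",
        "final_norm": "thinker.model.norm.weight",
    }

    def checkpoint_name(weight: str) -> str:
        # No prefix-stripping pass is needed: each weight object is constructed
        # with its full "thinker."-prefixed name and loads it as-is.
        return THINKER_PRE_POST_WEIGHTS[weight]
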
--- .../pre_and_post_layer_weight.py | 11 ------ .../pre_and_post_layer_weight.py | 37 +++++++++++-------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index 59c965b692..6449430d9e 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -1,18 +1,7 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import LMHeadWeight class Qwen2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - hidden_size = network_config["hidden_size"] - vocab_size = network_config["vocab_size"] - tie_word_embeddings = network_config.get("tie_word_embeddings", False) - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="thinker.lm_head.weight", - data_type=self.data_type_, - embedding_weight=self.wte_weight_ if tie_word_embeddings else None, - ) return diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py index 614a68bcb2..5ac8060c4e 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py @@ -1,23 +1,30 @@ -import numpy as np from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight - -# add key: language_model.xxx -> xxx -# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now -def rename_weight_keys(weights): - prefix = "thinker.model." 
- keys = list(weights.keys()) - for k in keys: - if prefix in k: - weights[k.replace(prefix, "model.")] = weights.pop(k) +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight class Qwen3OmniMOEThinkerPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): def __init__(self, data_type, network_config): - self.layer_num_ = network_config["num_hidden_layers"] super().__init__(data_type, network_config) - return - def load_hf_weights(self, weights): - rename_weight_keys(weights) - super().load_hf_weights(weights) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="thinker.model.embed_tokens.weight", + data_type=self.data_type_, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="thinker.lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name="thinker.model.norm.weight", + data_type=self.data_type_, + ) return From c21060d210d20feef011a712f990e11493f51a86 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 5 Feb 2026 07:38:31 +0000 Subject: [PATCH 05/20] fix --- .../common/basemodel/layer_weights/base_layer_weight.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/layer_weights/base_layer_weight.py b/lightllm/common/basemodel/layer_weights/base_layer_weight.py index 1875e2c3b3..b1d992a7c4 100644 --- a/lightllm/common/basemodel/layer_weights/base_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/base_layer_weight.py @@ -33,7 +33,11 @@ def verify_load(self): for attr_name in dir(self): attr = getattr(self, attr_name) if isinstance(attr, BaseWeight): - assert attr.verify_load(), f"Loading {attr_name} of layers {self.layer_num_} fails." + if hasattr(self, "layer_num_"): + layer_num = self.layer_num_ + else: + layer_num = None + assert attr.verify_load(), f"Loading {attr_name} of layers {layer_num} fails." def _cuda(self, cpu_tensor): return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id()) From 1ebec5a154dec08d82f88a4db24877397d931269 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 5 Feb 2026 07:47:46 +0000 Subject: [PATCH 06/20] fix transformer layer weight. 
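
The per-layer thinker weights live under "thinker.model.layers.{i}." in the checkpoint, so this patch drops the accidental overwrite of layer_num_ with the total layer count and overrides _init_weight_names / _init_moe to point every weight at its prefixed name. A rough sketch of the naming pattern the overrides encode (restated from the diff below; the helper is illustrative only):

    # Illustrative only: per-layer checkpoint names used by _init_weight_names / _init_moe.
    def thinker_layer_weight_names(layer_num: int) -> dict:
        prefix = f"thinker.model.layers.{layer_num}"
        return {
            "q_proj": f"{prefix}.self_attn.q_proj.weight",
            "k_proj": f"{prefix}.self_attn.k_proj.weight",
            "v_proj": f"{prefix}.self_attn.v_proj.weight",
            "o_proj": f"{prefix}.self_attn.o_proj.weight",
            "input_layernorm": f"{prefix}.input_layernorm.weight",
            "post_attention_layernorm": f"{prefix}.post_attention_layernorm.weight",
            "moe_gate": f"{prefix}.mlp.gate.weight",
            "experts_prefix": f"{prefix}.mlp.experts",
        }
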
--- .../transformers_layer_weight.py | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py index 220d846e0e..775ba5ffe2 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py @@ -5,6 +5,49 @@ class Qwen3OmniMOEThinkerTransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): - self.layer_num_ = network_config["num_hidden_layers"] super().__init__(layer_num, data_type, network_config, quant_cfg) return + + def _init_weight_names(self): + self._q_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_proj.weight" + self._q_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_norm.weight" + self._q_bias_name = None + self._k_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_proj.weight" + self._k_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_norm.weight" + self._k_bias_name = None + self._v_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.v_proj.weight" + self._v_bias_name = None + self._kv_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.kv_proj.weight" + self._kv_bias_name = None + self._o_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.o_proj.weight" + self._o_bias_name = None + self._att_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.input_layernorm.weight" + self._att_norm_bias_name = None + self._ffn_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.post_attention_layernorm.weight" + self._ffn_norm_bias_name = None + + def _init_moe(self): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"thinker.model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"thinker.model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) From bd24b7ef5799d1af1716bad686163eb00b2df7fd Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 5 Feb 2026 08:07:16 +0000 Subject: [PATCH 07/20] fix mrope --- lightllm/models/qwen2_vl/infer_struct.py | 4 ++-- .../qwen2_vl/triton_kernel/get_mrope_position_ids.py | 9 ++++++--- lightllm/models/qwen3_omni_moe_thinker/infer_struct.py | 4 +++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index 2508fd554a..67b2181ba5 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -34,7 +34,7 @@ def init_some_extra_state(self, model): self.position_sin = model._sin_cached[self.position_ids] return - def 
get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: + def get_mrope_position(self, multimodal_params: List[dict], is_qwen3_omini: bool = False) -> torch.Tensor: if len(multimodal_params) == 0: return self.position_ids.unsqueeze(0).expand(3, -1) b_image_start_idx = [] @@ -79,6 +79,6 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, b_start_loc=self.b_q_start_loc, - use_image_h=getattr(self, "use_image_h", True), + qwen3_omini_mode=is_qwen3_omini, ) return position_ids diff --git a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py index b5455fb4cf..161f5a4b1c 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py +++ b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py @@ -17,6 +17,7 @@ def _get_mrope_position_triton( b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, BLOCK_SIZE: tl.constexpr, + QWEN3_OMNI_MODE: tl.constexpr = False, ) -> torch.Tensor: cur_batch = tl.program_id(0) cache_len = tl.load(b_ready_cache_len + cur_batch) @@ -30,6 +31,9 @@ def _get_mrope_position_triton( image_len = tl.load(b_image_len + image_start_num + i) image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) + # qwen3-omini is right, older qwen3-vl is error setting, but must keep compatible + if QWEN3_OMNI_MODE: + image_h = image_w for j in range(0, image_len, BLOCK_SIZE): off = j + tl.arange(0, BLOCK_SIZE) # 目前没考虑视频,所以t 恒为 0 @@ -91,14 +95,12 @@ def get_mrope_position_triton( b_ready_cache_len: torch.Tensor, b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, - use_image_h: bool = True, + qwen3_omini_mode: bool = False, ) -> torch.Tensor: batch_size = b_q_seq_len.shape[0] assert batch_size == b_image_nums.shape[0] grid = (batch_size,) - if not use_image_h: # 也可以放在前面生成的地方改, 看哪里合适 - b_image_thwd[:, 1] = b_image_thwd[:, 2] BLOCK_SIZE = 64 _get_mrope_position_triton[grid]( b_image_start_idx=b_image_start_idx, @@ -113,6 +115,7 @@ def get_mrope_position_triton( b_q_seq_len=b_q_seq_len, b_start_loc=b_start_loc, BLOCK_SIZE=BLOCK_SIZE, + QWEN3_OMNI_MODE=qwen3_omini_mode, ) diff --git a/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py index a273f2abae..66bc79cfed 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py +++ b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py @@ -3,5 +3,7 @@ class Qwen3OmniMOEInferStateInfo(Qwen3VLInferStateInfo): def __init__(self): - self.use_image_h = False super().__init__() + + def get_mrope_position(self, multimodal_params): + return super().get_mrope_position(multimodal_params, is_qwen3_omini=True) From 68de2b34e42353856c1f697cb0c78582524c3310 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 5 Feb 2026 09:21:00 +0000 Subject: [PATCH 08/20] Fix mrope --- lightllm/models/qwen2_vl/infer_struct.py | 3 +-- .../qwen2_vl/triton_kernel/get_mrope_position_ids.py | 10 ++-------- lightllm/models/qwen3_omni_moe_thinker/infer_struct.py | 3 --- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index 67b2181ba5..747be932d9 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -34,7 +34,7 @@ def init_some_extra_state(self, 
model): self.position_sin = model._sin_cached[self.position_ids] return - def get_mrope_position(self, multimodal_params: List[dict], is_qwen3_omini: bool = False) -> torch.Tensor: + def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: if len(multimodal_params) == 0: return self.position_ids.unsqueeze(0).expand(3, -1) b_image_start_idx = [] @@ -79,6 +79,5 @@ def get_mrope_position(self, multimodal_params: List[dict], is_qwen3_omini: bool b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, b_start_loc=self.b_q_start_loc, - qwen3_omini_mode=is_qwen3_omini, ) return position_ids diff --git a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py index 161f5a4b1c..eace676c00 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py +++ b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py @@ -17,7 +17,6 @@ def _get_mrope_position_triton( b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, BLOCK_SIZE: tl.constexpr, - QWEN3_OMNI_MODE: tl.constexpr = False, ) -> torch.Tensor: cur_batch = tl.program_id(0) cache_len = tl.load(b_ready_cache_len + cur_batch) @@ -29,16 +28,13 @@ def _get_mrope_position_triton( local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) image_start_idx = start_loc + local_image_start_idx - cache_len image_len = tl.load(b_image_len + image_start_num + i) - image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) + # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) - # qwen3-omini is right, older qwen3-vl is error setting, but must keep compatible - if QWEN3_OMNI_MODE: - image_h = image_w for j in range(0, image_len, BLOCK_SIZE): off = j + tl.arange(0, BLOCK_SIZE) # 目前没考虑视频,所以t 恒为 0 t_pos = local_image_start_idx + off * 0 - h_pos = local_image_start_idx + off // image_h + h_pos = local_image_start_idx + off // image_w w_pos = local_image_start_idx + off % image_w tl.store( position_ids + off + image_start_idx, @@ -95,7 +91,6 @@ def get_mrope_position_triton( b_ready_cache_len: torch.Tensor, b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, - qwen3_omini_mode: bool = False, ) -> torch.Tensor: batch_size = b_q_seq_len.shape[0] @@ -115,7 +110,6 @@ def get_mrope_position_triton( b_q_seq_len=b_q_seq_len, b_start_loc=b_start_loc, BLOCK_SIZE=BLOCK_SIZE, - QWEN3_OMNI_MODE=qwen3_omini_mode, ) diff --git a/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py index 66bc79cfed..1c09ebf446 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py +++ b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py @@ -4,6 +4,3 @@ class Qwen3OmniMOEInferStateInfo(Qwen3VLInferStateInfo): def __init__(self): super().__init__() - - def get_mrope_position(self, multimodal_params): - return super().get_mrope_position(multimodal_params, is_qwen3_omini=True) From 886306d4bfe0f55f9ab79d2f2fb8f0ec2fc46f80 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 5 Feb 2026 09:43:57 +0000 Subject: [PATCH 09/20] fix cpu cache impl. 
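
Qwen3-Omni reuses the Qwen3-VL deepstack layout for the CPU embed cache, so its model class is added to the branch in calcu_embed_cache_meta that selects layer_num=4. A rough sketch of the per-token cache entry this corresponds to (shapes are assumptions for illustration; the cat/view mirrors concat_img_embed_and_deepstack_features in the visual encoder):

    import torch

    # One merged visual embedding plus the three deepstack features
    # (deepstack_visual_indexes = [8, 16, 24]) are stored per image token.
    tokens, hidden = 8, 4096  # assumed sizes, not taken from the config
    main_embed = torch.zeros(tokens, hidden)
    deepstack = [torch.zeros(tokens, hidden) for _ in range(3)]
    entry = torch.cat([main_embed, *deepstack], dim=1).view(tokens, 4, hidden)
    assert entry.shape == (tokens, 4, hidden)  # matches layer_num=4 below
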
--- lightllm/utils/embed_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/utils/embed_utils.py b/lightllm/utils/embed_utils.py index 81b05dc291..dec1537d66 100644 --- a/lightllm/utils/embed_utils.py +++ b/lightllm/utils/embed_utils.py @@ -27,12 +27,12 @@ def calcu_embed_cache_meta() -> "EmbedCacheMeta": args = get_env_start_args() assert args.enable_multimodal from lightllm.utils.llm_utils import get_llm_model_class - from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel + from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel, Qwen3OmniMOETpPartModel model_class = get_llm_model_class() model_dir = args.model_dir - if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel]: + if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel, Qwen3OmniMOETpPartModel]: embed_cache_meta_data = EmbedCacheMeta( token_num=None, layer_num=4, From e988b20b69287317ccb12034d4d19e958cc9172e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 5 Feb 2026 09:49:38 +0000 Subject: [PATCH 10/20] add embed cache vision function --- lightllm/server/embed_cache/embed_cache_client.py | 13 +++++++++++++ .../server/visualserver/model_infer/model_rpc.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/lightllm/server/embed_cache/embed_cache_client.py b/lightllm/server/embed_cache/embed_cache_client.py index ad01dd8082..8c5b7f71ee 100644 --- a/lightllm/server/embed_cache/embed_cache_client.py +++ b/lightllm/server/embed_cache/embed_cache_client.py @@ -47,6 +47,19 @@ def copy_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int): start_index_in_cache=start_index_in_cache, ) + def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int): + from .copy_to_cache import offload_embed_tensor_to_cache + + if embed_tensor.ndim == 3: + # check for qwen3 vision embed tensor shape, use apply deepstack + assert embed_tensor.shape[1] == self.cpu_embed_cache_tensor.shape[1] + + offload_embed_tensor_to_cache( + embed_tensor=embed_tensor, + cache_tensor=self.cpu_embed_cache_tensor, + start_index_in_cache=start_index_in_cache, + ) + def _create_shm_embed_kv_cache(self): shm_ptr = create_shm_kv_cache_ptr( key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size() diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index e034dc662d..3e97f4de3e 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -127,7 +127,7 @@ def exposed_encode(self, images: List[ImageItem]): uid = uuids[i] start, end = valid_ids[i] image = images[i] - self.cpu_embed_cache_client.copy_to_cache( + self.cpu_embed_cache_client.copy_vision_to_cache( embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache ) ids_to_set.append(uid) From 74ae14c35222d1698a1e2c63b1aff150e7c2b1a3 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 5 Feb 2026 09:59:58 +0000 Subject: [PATCH 11/20] fix0205 --- .../qwen3_omni_moe_thinker/audio_process.py | 2 +- .../qwen3_omni_audio.py | 60 ++++--------------- 2 files changed, 13 insertions(+), 49 deletions(-) diff --git a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py index 3b15bd271e..74f616aedc 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py +++ b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py @@ -177,4 
+177,4 @@ def _preprocess( if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - return padded_inputs + return padded_inputs["input_features"], padded_inputs["attention_mask"] diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py index 7e62c14ecd..9d41de4df3 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py @@ -333,13 +333,8 @@ def forward( return hidden_states def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient): - # 每个元素是一个chunk - batch_audios = [] - batch_audio_lens = [] uuids = [] items: List[AudioItem] = [] - # 记录每个chunk属于哪个audio_items下标 - chunk_owner_index = [] for i, item in enumerate(audio_items): if isinstance(item, AudioItem): uuids.append(item.uuid) @@ -349,48 +344,19 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedC audio, _ = librosa.load(audio, sr=16000) else: raise ValueError(f"cannot read audio which type is {type(item)}!") + # 这里后面还要改 + input_features, feature_attention_mask = self.processor._preprocess(audio) + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = None - # padding to min audio len - MIN_AUDIO_LEN = 480 - - if audio.shape[0] < MIN_AUDIO_LEN: - audio = np.pad(audio, (0, MIN_AUDIO_LEN - len(audio)), mode="constant", constant_values=0.0) - - if audio.shape[0] > self.max_length: - start = 0 - while start < audio.shape[0]: - end = min(start + self.max_length, audio.shape[0]) - chunk = audio[start:end] - - if chunk.shape[0] < MIN_AUDIO_LEN: - chunk = np.pad(chunk, (0, MIN_AUDIO_LEN - chunk.shape[0]), mode="constant", constant_values=0.0) - batch_audios.append(chunk) - batch_audio_lens.append(min(chunk.shape[0], self.max_length)) - chunk_owner_index.append(i) - - start = end - else: - batch_audio_lens.append(min(audio.shape[0], self.max_length)) - batch_audios.append(audio) - chunk_owner_index.append(i) - - batch_audio_lens = np.array(batch_audio_lens, dtype=np.int32) - - audios, audio_lens_after_cnn = self.processor._preprocess( - batch_audios, sampling_rate=16000, return_tensors="pt" + feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_features = self.forward( + input_features, + feature_lens=feature_lens, ) - audios = self.forward(audios, audio_lens_after_cnn) - audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32) - audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1 - - num_audios = len(audio_items) - per_audio_embeds = [[] for _ in range(num_audios)] - - for chunk_idx, owner in enumerate(chunk_owner_index): - token_len = int(audio_token_num[chunk_idx]) - if token_len <= 0: - continue - per_audio_embeds[owner].append(audios[chunk_idx][:token_len]) ready_audio = obtain(self.cache_client.root.get_items_embed(uuids)) ids_to_set = [] @@ -401,10 +367,8 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedC uid = uuids[i] item = items[i] - # 拼接该 audio 的所有 chunk embedding - cur_embed = torch.cat(per_audio_embeds[i], dim=0) cpu_embed_cache_client.copy_to_cache( - embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache + embed_tensor=audio_features, 
start_index_in_cache=item.start_index_in_embed_cache ) ids_to_set.append(uid) From 169fe84fe1419ecc53afbc57acce39ce0eb1fccb Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 5 Feb 2026 11:48:03 +0000 Subject: [PATCH 12/20] add-audio --- .../models/qwen3_omni_moe_thinker/audio_process.py | 13 +++++++++---- .../qwen3_omni_moe_thinker/qwen3_omni_audio.py | 9 +++++++-- .../server/audioserver/model_infer/model_rpc.py | 2 +- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py index 74f616aedc..833cc8f4b0 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py +++ b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py @@ -98,9 +98,9 @@ def _preprocess( pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_attention_mask: Optional[bool] = None, - padding: Optional[str] = "max_length", + padding: Optional[str] = "longest", # max_length代表padding到max_length max_length: Optional[int] = None, - sampling_rate: Optional[int] = None, + sampling_rate: Optional[int] = 16000, do_normalize: Optional[bool] = None, device: Optional[str] = "cpu", return_token_timestamps: Optional[bool] = None, @@ -176,5 +176,10 @@ def _preprocess( if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs["input_features"], padded_inputs["attention_mask"] + input_features = torch.from_numpy(np.asarray(padded_inputs["input_features"], dtype=np.float32)).to( + device="cuda", dtype=torch.bfloat16 + ) + attention_mask = torch.from_numpy(np.asarray(padded_inputs["attention_mask"], dtype=np.float32)).to( + device="cuda", dtype=torch.int32 + ) + return input_features, attention_mask diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py index 9d41de4df3..7ca97a516a 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py @@ -2,6 +2,7 @@ import json import math import torch +import rpyc import librosa import numpy as np from io import BytesIO @@ -206,6 +207,9 @@ def __init__( self.proj2 = nn.Linear(d_model, output_dim) self.n_window_infer = n_window_infer self.conv_chunksize = conv_chunksize + + self.cache_port = kvargs["cache_port"] + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self._init_datatype() def _init_datatype(self): @@ -271,6 +275,7 @@ def forward( feature_lens=None, aftercnn_lens=None, ): + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() chunk_lengths = torch.tensor( @@ -344,8 +349,8 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedC audio, _ = librosa.load(audio, sr=16000) else: raise ValueError(f"cannot read audio which type is {type(item)}!") - # 这里后面还要改 - input_features, feature_attention_mask = self.processor._preprocess(audio) + + input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True) if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py 
index 8c123bbd1c..ad18ffd7f5 100644 --- a/lightllm/server/audioserver/model_infer/model_rpc.py +++ b/lightllm/server/audioserver/model_infer/model_rpc.py @@ -31,7 +31,7 @@ def exposed_init_model(self, kvargs): if self.model_type == "clap_audio_model" or self.model_type == "whisper": self.model = WhisperAudioModel(model_kvargs) elif self.model_type == "qwen3_omni_moe_audio_encoder": - self.model = Qwen3OmniMoeAudioEncoder(model_kvargs) + self.model = Qwen3OmniMoeAudioEncoder(model_kvargs).eval().bfloat16() else: raise Exception(f"can not support {self.model_type} now") From d0ed9d4404abff05ff342d58e15a67e824a1b7b3 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 5 Feb 2026 12:23:55 +0000 Subject: [PATCH 13/20] add-audio --- .../models/qwen3_omni_moe_thinker/model.py | 125 +++++++++++++----- .../qwen3_omni_audio.py | 5 + 2 files changed, 96 insertions(+), 34 deletions(-) diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py index c8de223e65..82fa71ba0e 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/model.py +++ b/lightllm/models/qwen3_omni_moe_thinker/model.py @@ -1,5 +1,7 @@ import os import json +import librosa +from io import BytesIO from lightllm.common.build_utils import repair_config from lightllm.models.registry import ModelRegistry from lightllm.models.qwen3_moe.model import Qwen3MOEModel @@ -20,10 +22,21 @@ from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +def _get_feat_extract_output_lengths(input_lengths): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths + + # <|audio_start|><|audio_pad|><|audio_end|> AUDIO_START_TOKEN = "<|audio_start|>" AUDIO_END_TOKEN = "<|audio_end|>" - +AUDIO_TOKEN_TOKEN = "<|audio_pad|>" MIN_AUDIO_LEN = 480 @@ -45,8 +58,14 @@ def __init__(self, tokenizer=None, image_processor=None, **kwargs): self.audio_end_tag = AUDIO_END_TOKEN self.audio_end_id = tokenizer.convert_tokens_to_ids(self.audio_end_tag) - self.audio_min_length = MIN_AUDIO_LEN - self.audio_max_length = 16000 * 30 + self.audio_token_tag = AUDIO_TOKEN_TOKEN + self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token_tag) + + # 这些太hard了, 后面改一下,可以直接从audio_processor里取? 
+ self.sampling_rate = 16000 + self.chunk_length = 30 + self.n_samples = self.chunk_length * self.sampling_rate + self.hop_length = 160 def init_audioitem_extral_params( self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams @@ -54,37 +73,75 @@ def init_audioitem_extral_params( return def get_audio_token_length(self, audio: AudioItem): - L = audio.audio_length - audio_token_num = 0 - chunk_lens = [] - if L <= self.audio_max_length: - cur_len = L - if cur_len < self.audio_min_length: - cur_len = self.audio_min_length - chunk_lens.append(cur_len) - else: - start = 0 - while start < L: - end = min(start + self.audio_max_length, L) - cur_len = end - start - - if cur_len < self.audio_min_length: - cur_len = self.audio_min_length - - chunk_lens.append(cur_len) - start = end - for chunk_len in chunk_lens: - mel_len = chunk_len // 160 - dilation = 1 - L_in = mel_len - for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "): - L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1 - L_out = 1 + L_out // stride - L_in = L_out - audio_len_after_cnn = L_out - chunk_token_num = (audio_len_after_cnn - 2) // 2 + 1 - audio_token_num += int(chunk_token_num) - return audio_token_num + # audio_bytes = audio._preload_data + # audio_values, _ = librosa.load(BytesIO(audio_bytes), sr=self.sampling_rate) + # length = max(int(audio_values.shape[0]), int(MIN_AUDIO_LEN)) #这个最短还有必要吗?稍等再检查一下 + # L_eff = min(length, int(self.n_samples)) + # num_frames = L_eff // int(self.hop_length) + + return 290 + + def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): + origin_ids = self.tokenizer.encode(prompt) + + # -> + origin_ids = [token for token in origin_ids if token not in (self.image_token_id, self.audio_token_id)] + # --> id,id+1...id+num + input_ids = [] + image_id = 0 + while True: + try: + start_idx = origin_ids.index(self.image_start_id) + if start_idx + 1 >= len(origin_ids): + break + if origin_ids[start_idx + 1] == self.image_end_id: + input_ids.extend(origin_ids[: start_idx + 1]) + token_id = multimodal_params.images[image_id].token_id + token_num = multimodal_params.images[image_id].token_num + multimodal_params.images[image_id].start_idx = len(input_ids) + input_ids.extend(range(token_id, token_id + token_num)) + input_ids.append(self.image_end_id) + origin_ids = origin_ids[start_idx + 2 :] + image_id += 1 + else: + raise ValueError("image token error") + except ValueError: + break + if multimodal_params: + image_cnt = len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!") + input_ids.extend(origin_ids) + + # audio + origin_ids = input_ids + input_ids = [] + audio_id = 0 + start_idx = 0 + while True: + try: + start_idx = origin_ids.index(self.audio_start_id) + if start_idx + 1 >= len(origin_ids): + break + if origin_ids[start_idx + 1] == self.audio_end_id: + input_ids.extend(origin_ids[: start_idx + 1]) + token_id = multimodal_params.audios[audio_id].token_id + token_num = multimodal_params.audios[audio_id].token_num + input_ids.extend(range(token_id, token_id + token_num)) + input_ids.append(self.audio_end_id) + origin_ids = origin_ids[start_idx + 2 :] + audio_id += 1 + else: + raise ValueError("audio token error") + except ValueError: + break + if multimodal_params: + audio_cnt = len(multimodal_params.audios) + if audio_cnt != audio_id: + raise ValueError(audio_cnt == audio_id, f"invalid audio tag num: {audio_cnt} vs {audio_id}!") + 
input_ids.extend(origin_ids) + + return input_ids @ModelRegistry(["qwen3_omni_moe"], is_multimodal=True) diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py index 7ca97a516a..7582b22272 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py @@ -351,17 +351,22 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedC raise ValueError(f"cannot read audio which type is {type(item)}!") input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True) + print(f"input_features is {input_features}, input_features.shape is {input_features.shape}") + print(f"feature_attention_mask is {feature_attention_mask}, shape is {feature_attention_mask.shape}") if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) else: audio_feature_lengths = None + print(f"input_features is {input_features}, input_features.shape is {input_features.shape}") feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + print(f"feature_lens is {feature_lens}") audio_features = self.forward( input_features, feature_lens=feature_lens, ) + print(f"audio_features is {audio_features}, shape is {audio_features.shape}") ready_audio = obtain(self.cache_client.root.get_items_embed(uuids)) ids_to_set = [] From 753e13f0fa8a3332c1aa49e85b156295baf13677 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 5 Feb 2026 13:25:45 +0000 Subject: [PATCH 14/20] add-audio --- .../models/qwen3_omni_moe_thinker/model.py | 44 +++++++------------ .../qwen3_omni_audio.py | 43 ++++++++++-------- lightllm/server/tokenizer.py | 2 +- 3 files changed, 42 insertions(+), 47 deletions(-) diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py index 82fa71ba0e..53fcbf39c0 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/model.py +++ b/lightllm/models/qwen3_omni_moe_thinker/model.py @@ -33,39 +33,32 @@ def _get_feat_extract_output_lengths(input_lengths): return output_lengths -# <|audio_start|><|audio_pad|><|audio_end|> -AUDIO_START_TOKEN = "<|audio_start|>" -AUDIO_END_TOKEN = "<|audio_end|>" -AUDIO_TOKEN_TOKEN = "<|audio_pad|>" MIN_AUDIO_LEN = 480 class QWen3OmniTokenizer(QWen3VLTokenizer): - def __init__(self, tokenizer=None, image_processor=None, **kwargs): + def __init__(self, tokenizer=None, processor=None, **kwargs): self.tokenizer = tokenizer - self.image_processor = image_processor + # image + self.image_processor = processor.image_processor self.min_pixel = self.image_processor.min_pixels self.max_pixel = self.image_processor.max_pixels self.patch_size = self.image_processor.patch_size self.merge_size = self.image_processor.merge_size + + # audio + self.audio_processor = processor.feature_extractor + self.sampling_rate = self.audio_processor.sampling_rate + self.n_samples = self.audio_processor.n_samples + self.hop_length = self.audio_processor.hop_length + self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"] self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"] self.image_token_id = kwargs["model_cfg"]["image_token_id"] - self.audio_start_tag = AUDIO_START_TOKEN - self.audio_start_id = tokenizer.convert_tokens_to_ids(self.audio_start_tag) - - 
self.audio_end_tag = AUDIO_END_TOKEN - self.audio_end_id = tokenizer.convert_tokens_to_ids(self.audio_end_tag) - - self.audio_token_tag = AUDIO_TOKEN_TOKEN - self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token_tag) - - # 这些太hard了, 后面改一下,可以直接从audio_processor里取? - self.sampling_rate = 16000 - self.chunk_length = 30 - self.n_samples = self.chunk_length * self.sampling_rate - self.hop_length = 160 + self.audio_start_id = kwargs["model_cfg"]["audio_start_token_id"] + self.audio_end_id = kwargs["model_cfg"]["audio_end_token_id"] + self.audio_token_id = kwargs["model_cfg"]["audio_token_id"] def init_audioitem_extral_params( self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams @@ -73,13 +66,10 @@ def init_audioitem_extral_params( return def get_audio_token_length(self, audio: AudioItem): - # audio_bytes = audio._preload_data - # audio_values, _ = librosa.load(BytesIO(audio_bytes), sr=self.sampling_rate) - # length = max(int(audio_values.shape[0]), int(MIN_AUDIO_LEN)) #这个最短还有必要吗?稍等再检查一下 - # L_eff = min(length, int(self.n_samples)) - # num_frames = L_eff // int(self.hop_length) - - return 290 + length = min(audio.audio_length, int(self.n_samples)) + token_num = length // int(self.hop_length) + print(f"token_num is {token_num}") + return token_num def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): origin_ids = self.tokenizer.encode(prompt) diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py index 7582b22272..0847008d3c 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py @@ -340,33 +340,37 @@ def forward( def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient): uuids = [] items: List[AudioItem] = [] + per_audio_features: List[torch.Tensor] = [] for i, item in enumerate(audio_items): if isinstance(item, AudioItem): uuids.append(item.uuid) items.append(item) audio_data = read_shm(get_shm_name_data(item.uuid)) audio = BytesIO(audio_data) - audio, _ = librosa.load(audio, sr=16000) + audio, _ = librosa.load(audio, sr=self.processor.sampling_rate) else: raise ValueError(f"cannot read audio which type is {type(item)}!") - input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True) - print(f"input_features is {input_features}, input_features.shape is {input_features.shape}") - print(f"feature_attention_mask is {feature_attention_mask}, shape is {feature_attention_mask.shape}") - if feature_attention_mask is not None: - audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) - input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) - else: - audio_feature_lengths = None - print(f"input_features is {input_features}, input_features.shape is {input_features.shape}") - - feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - print(f"feature_lens is {feature_lens}") - audio_features = self.forward( - input_features, - feature_lens=feature_lens, - ) - print(f"audio_features is {audio_features}, shape is {audio_features.shape}") + input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True) + print(f"input_features is {input_features}, input_features.shape is {input_features.shape}") + print(f"feature_attention_mask is {feature_attention_mask}, shape 
is {feature_attention_mask.shape}") + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = None + print(f"input_features is {input_features}, input_features.shape is {input_features.shape}") + + feature_lens = ( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + print(f"feature_lens is {feature_lens}") + audio_features = self.forward( + input_features, + feature_lens=feature_lens, + ) + per_audio_features.append(audio_features) + print(f"audio_features is {audio_features}, shape is {audio_features.shape}") ready_audio = obtain(self.cache_client.root.get_items_embed(uuids)) ids_to_set = [] @@ -377,8 +381,9 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedC uid = uuids[i] item = items[i] + cur_embed = per_audio_features[i] cpu_embed_cache_client.copy_to_cache( - embed_tensor=audio_features, start_index_in_cache=item.start_index_in_embed_cache + embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache ) ids_to_set.append(uid) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index c245fa85a6..09bc938f23 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -106,7 +106,7 @@ def get_tokenizer( model_cfg = model_cfg["thinker_config"] processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen3OmniTokenizer(tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg) + tokenizer = QWen3OmniTokenizer(tokenizer, processor=processor, model_cfg=model_cfg) elif model_type == "internvl_chat": tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": From 910a3f2b7b8c2788e538f379f03773ab3ce1091d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 6 Feb 2026 06:17:27 +0000 Subject: [PATCH 15/20] fix --- .../layer_infer/transformer_layer_infer.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py index 91c65e6750..1a05a752f3 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py @@ -1,27 +1,12 @@ -import os import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np -import triton -from typing import Tuple from lightllm.models.qwen3_vl_moe.layer_infer.transformer_layer_infer import Qwen3VLMOETransformerLayerInfer -from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd -from functools import partial from lightllm.utils.log_utils import init_logger -from lightllm.utils.dist_utils import get_global_world_size -from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor logger = init_logger(__name__) class Qwen3OmniMOETransformerLayerInfer(Qwen3VLMOETransformerLayerInfer): def 
__init__(self, layer_num, network_config): - self.layer_num_ = network_config["num_hidden_layers"] super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] self.mrope_section = torch.tensor( From 31381cca1eb12a8dcae303df36df668e5dae858e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 6 Feb 2026 06:53:22 +0000 Subject: [PATCH 16/20] fix requirements. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a3b9473f82..df23f61d52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,7 +61,7 @@ sortedcontainers==2.4.0 toolz==0.12.0 torch==2.8.0 tqdm==4.65.0 -transformers==4.53.3 +transformers==4.57.1 tokenizers==0.21.1 urllib3==1.26.16 uvicorn==0.19.0 From 94d3eaa8685824d14a9a1e079ad057c3f55d09cb Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 6 Feb 2026 07:07:45 +0000 Subject: [PATCH 17/20] fix. --- .../qwen3_vl_moe/layer_infer/transformer_layer_infer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index 4ccc6da372..0bf716fc56 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -23,8 +23,10 @@ def _get_qkv( layer_weight: Qwen3MOETransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input) + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) layer_weight.q_norm_weight_( q, eps=self.eps_, From f17f68c615d5282774c1efa5b97bc8726ed9a03d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 6 Feb 2026 09:09:32 +0000 Subject: [PATCH 18/20] fix audio token num calcu --- lightllm/models/qwen3_omni_moe_thinker/model.py | 12 ++++++++++-- .../qwen3_omni_moe_thinker/qwen3_omni_audio.py | 6 +----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py index 53fcbf39c0..9ede8cec1a 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/model.py +++ b/lightllm/models/qwen3_omni_moe_thinker/model.py @@ -66,11 +66,19 @@ def init_audioitem_extral_params( return def get_audio_token_length(self, audio: AudioItem): + # 这里得处理对应奖语音长度按照 30 进行限制,后续处理中,超过30的会被截断。 length = min(audio.audio_length, int(self.n_samples)) - token_num = length // int(self.hop_length) - print(f"token_num is {token_num}") + token_num = self._caclu_audio_token_num(length) + # print(f"token_num is {token_num} n_samples is {self.n_samples} hop_length is {self.hop_length}") return token_num + def _caclu_audio_token_num(self, input_audio_len: int): + _mel_len = input_audio_len // int(self.hop_length) + input_lengths_leave = _mel_len % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (_mel_len // 100) * 13 + return output_lengths + def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): origin_ids = self.tokenizer.encode(prompt) diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py index 0847008d3c..c66033f532 100644 --- 
a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py @@ -352,25 +352,21 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedC raise ValueError(f"cannot read audio which type is {type(item)}!") input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True) - print(f"input_features is {input_features}, input_features.shape is {input_features.shape}") - print(f"feature_attention_mask is {feature_attention_mask}, shape is {feature_attention_mask.shape}") if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) else: audio_feature_lengths = None - print(f"input_features is {input_features}, input_features.shape is {input_features.shape}") feature_lens = ( audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) ) - print(f"feature_lens is {feature_lens}") + audio_features = self.forward( input_features, feature_lens=feature_lens, ) per_audio_features.append(audio_features) - print(f"audio_features is {audio_features}, shape is {audio_features.shape}") ready_audio = obtain(self.cache_client.root.get_items_embed(uuids)) ids_to_set = [] From 9a0b6799dd45e3a1b5f46498d02b4621e1b77d16 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 6 Feb 2026 10:55:45 +0000 Subject: [PATCH 19/20] fix-audio --- lightllm/models/qwen3_omni_moe_thinker/model.py | 13 ------------- requirements.txt | 2 +- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py index 9ede8cec1a..2e863da001 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/model.py +++ b/lightllm/models/qwen3_omni_moe_thinker/model.py @@ -22,17 +22,6 @@ from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem -def _get_feat_extract_output_lengths(input_lengths): - """ - Computes the output length of the convolutional layers and the output length of the audio encoder - """ - - input_lengths_leave = input_lengths % 100 - feat_lengths = (input_lengths_leave - 1) // 2 + 1 - output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 - return output_lengths - - MIN_AUDIO_LEN = 480 @@ -69,7 +58,6 @@ def get_audio_token_length(self, audio: AudioItem): # 这里得处理对应奖语音长度按照 30 进行限制,后续处理中,超过30的会被截断。 length = min(audio.audio_length, int(self.n_samples)) token_num = self._caclu_audio_token_num(length) - # print(f"token_num is {token_num} n_samples is {self.n_samples} hop_length is {self.hop_length}") return token_num def _caclu_audio_token_num(self, input_audio_len: int): @@ -162,7 +150,6 @@ def _init_config(self): all_config = json.load(json_file) self.config = all_config["thinker_config"]["text_config"] # rename keys - print(f"self.config is {self.config}") repair_config(self.config, same_names=["num_attention_heads", "n_head"]) repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) diff --git a/requirements.txt b/requirements.txt index df23f61d52..89c52ebf86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -62,7 +62,7 @@ toolz==0.12.0 torch==2.8.0 tqdm==4.65.0 transformers==4.57.1 -tokenizers==0.21.1 +tokenizers==0.22.1 urllib3==1.26.16 uvicorn==0.19.0 uvloop==0.17.0 From 
7592a22aff2eef6425f92bb953769ff2affd06da Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 6 Feb 2026 10:59:40 +0000 Subject: [PATCH 20/20] fix-audio --- .../meta_weights/code2wav_causal_conv_net.py | 116 ------------ .../code2wav_causal_trans_conv_net.py | 101 ----------- .../meta_weights/code2wav_conv_ne_xt.py | 165 ------------------ .../meta_weights/talker_resize_mlp_weight.py | 109 ------------ 4 files changed, 491 deletions(-) delete mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_conv_net.py delete mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_trans_conv_net.py delete mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_conv_ne_xt.py delete mode 100644 lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/talker_resize_mlp_weight.py diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_conv_net.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_conv_net.py deleted file mode 100644 index 1caabf31da..0000000000 --- a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_conv_net.py +++ /dev/null @@ -1,116 +0,0 @@ -import math -import torch -import numpy as np -from typing import Dict, Optional -from transformers.activations import ACT2FN -from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl -from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp -from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel -from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp - - -class Qwen3OmniMoeCausalConvNetWeight(BaseWeightTpl, PlatformAwareOp): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - weight_name: str, - bias_name: str, - data_type: torch.dtype, - dilation: int = 1, - stride: int = 1, - groups: int = 1, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.dilation = dilation - self.stride = stride - self.groups = groups - self.weight_name: str = weight_name - self.bias_name: str = bias_name - self.data_type_ = data_type - - self.kernel_size_effective = (kernel_size - 1) * dilation + 1 - self.padding = self.kernel_size_effective - self.stride - - self._create_weight() - - def _create_weight(self): - # Conv1d weight shape: (out_channels, in_channels // groups, kernel_size) - self.weight: torch.Tensor = torch.empty( - self.out_channels, - self.in_channels // self.groups, - self.kernel_size, - dtype=self.data_type_, - device=self.device_id_, - ) - self.bias: torch.Tensor = torch.empty(self.out_channels, dtype=self.data_type_, device=self.device_id_) - self.weight.load_ok = False - self.bias.load_ok = False - - def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.weight_name in weights: - t_weight = weights[self.weight_name] - assert t_weight.shape == ( - self.out_channels, - self.in_channels // self.groups, - self.kernel_size, - ) - self.weight.copy_(t_weight.to(self.data_type_)) - self.weight.load_ok = True - if self.bias_name in weights: - t_bias = weights[self.bias_name] - assert t_bias.shape == (self.out_channels,) - self.bias.copy_(t_bias.to(self.data_type_)) - self.bias.load_ok = True - - def verify_load(self): - return self.weight.load_ok and self.bias.load_ok - - def 
_get_extra_padding_for_conv1d(self, hidden_state: torch.Tensor) -> int: - length = hidden_state.shape[-1] - n_frames = (length - self.kernel_size_effective + self.padding) / self.stride + 1 - ideal_length = (math.ceil(n_frames) - 1) * self.stride + (self.kernel_size_effective - self.padding) - return int(ideal_length - length) - - def _native_forward( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty - ) -> torch.Tensor: - extra_padding = self._get_extra_padding_for_conv1d(hidden_state) - hidden_state = torch.nn.functional.pad(hidden_state, (self.padding, extra_padding), mode="constant", value=0) - x = torch.nn.functional.conv1d( - hidden_state, - self.weight, - self.bias, - stride=self.stride, - padding=0, - dilation=self.dilation, - groups=self.groups, - ).contiguous() - if out is not None: - out.copy_(x) - return out - return x - - def _cuda_forward( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - if out is None: - result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) - return result - result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) - out.copy_(result) - return out - - def _musa_forward( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - return self._cuda_forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) - - def __call__( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - return self._forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_trans_conv_net.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_trans_conv_net.py deleted file mode 100644 index 058565e848..0000000000 --- a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_causal_trans_conv_net.py +++ /dev/null @@ -1,101 +0,0 @@ -import math -import torch -import numpy as np -from typing import Dict, Optional -from transformers.activations import ACT2FN -from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl -from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp -from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel -from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp - - -class Qwen3OmniMoeCode2wavCausalTransConvNetWeight(BaseWeightTpl, PlatformAwareOp): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int, - weight_name: str, - bias_name: str, - data_type: torch.dtype, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.weight_name: str = weight_name - self.bias_name: str = bias_name - self.data_type_ = data_type - - pad = kernel_size - stride - self.left_pad = math.ceil(pad) - self.right_pad = pad = self.left_pad - - self._create_weight() - - def _create_weight(self): - # ConvTranspose1d weight shape: (in_channels, out_channels, kernel_size) when groups=1 - self.weight: torch.Tensor = torch.empty( - self.in_channels, self.out_channels, self.kernel_size, dtype=self.data_type_, device=self.device_id_ - ) - self.bias: torch.Tensor = 
torch.empty(self.out_channels, dtype=self.data_type_, device=self.device_id_) - self.weight.load_ok = False - self.bias.load_ok = False - - def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.weight_name in weights: - t_weight = weights[self.weight_name] - assert t_weight.shape == (self.in_channels, self.out_channels, self.kernel_size) - self.weight.copy_(t_weight.to(self.data_type_)) - self.weight.load_ok = True - if self.bias_name in weights: - t_bias = weights[self.bias_name] - assert t_bias.shape == (self.out_channels,) - self.bias.copy_(t_bias.to(self.data_type_)) - self.bias.load_ok = True - - def verify_load(self): - return self.weight.load_ok and self.bias.load_ok - - def _native_forward( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty - ) -> torch.Tensor: - # hidden_state: [B, C_in, L] - x = torch.nn.functional.conv_transpose1d( - hidden_state, - self.weight, - self.bias, - stride=self.stride, - padding=0, - output_padding=0, - groups=1, - dilation=1, - ) - x = x[..., self.left_pad : x.shape[-1] - self.right_pad].contiguous() - if out is not None: - out.copy_(x) - return out - return x - - def _cuda_forward( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - if out is None: - # output length depends on input length; allocate after computing - result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) - return result - result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) - out.copy_(result) - return out - - def _musa_forward( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - return self._cuda_forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) - - def __call__( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - return self._forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_conv_ne_xt.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_conv_ne_xt.py deleted file mode 100644 index 21b0462dce..0000000000 --- a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/code2wav_conv_ne_xt.py +++ /dev/null @@ -1,165 +0,0 @@ -import math -import torch -import numpy as np -from typing import Dict, Optional -from transformers.activations import ACT2FN -from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl -from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp -from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel -from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp - - -class Qwen3OmniMoeConvNeXtBlockWeight(BaseWeightTpl, PlatformAwareOp): - def __init__( - self, - dim: int, - dwconv, - norm_weight_name: str, - norm_bias_name: str, - pwconv1_weight_name: str, - pwconv1_bias_name: str, - pwconv2_weight_name: str, - pwconv2_bias_name: str, - gamma_name: str, - data_type: torch.dtype, - eps: float = 1e-6, - ): - super().__init__() - self.dim = dim - self.dwconv = dwconv - self.norm_weight_name = norm_weight_name - self.norm_bias_name = norm_bias_name - self.pwconv1_weight_name = pwconv1_weight_name - self.pwconv1_bias_name = pwconv1_bias_name - self.pwconv2_weight_name 
= pwconv2_weight_name - self.pwconv2_bias_name = pwconv2_bias_name - self.gamma_name = gamma_name - self.data_type_ = data_type - self.eps = eps - - self._create_weight() - - def _create_weight(self): - self.norm_weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) - self.norm_bias: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) - - self.pwconv1_weight: torch.Tensor = torch.empty( - 4 * self.dim, self.dim, dtype=self.data_type_, device=self.device_id_ - ) - self.pwconv1_bias: torch.Tensor = torch.empty(4 * self.dim, dtype=self.data_type_, device=self.device_id_) - - self.pwconv2_weight: torch.Tensor = torch.empty( - self.dim, 4 * self.dim, dtype=self.data_type_, device=self.device_id_ - ) - self.pwconv2_bias: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) - - self.gamma: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) - - self.norm_weight.load_ok = False - self.norm_bias.load_ok = False - self.pwconv1_weight.load_ok = False - self.pwconv1_bias.load_ok = False - self.pwconv2_weight.load_ok = False - self.pwconv2_bias.load_ok = False - self.gamma.load_ok = False - - def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.norm_weight_name in weights: - t = weights[self.norm_weight_name] - assert t.shape == (self.dim,) - self.norm_weight.copy_(t.to(self.data_type_)) - self.norm_weight.load_ok = True - - if self.norm_bias_name in weights: - t = weights[self.norm_bias_name] - assert t.shape == (self.dim,) - self.norm_bias.copy_(t.to(self.data_type_)) - self.norm_bias.load_ok = True - - if self.pwconv1_weight_name in weights: - t = weights[self.pwconv1_weight_name] - assert t.shape == (4 * self.dim, self.dim) - self.pwconv1_weight.copy_(t.to(self.data_type_)) - self.pwconv1_weight.load_ok = True - - if self.pwconv1_bias_name in weights: - t = weights[self.pwconv1_bias_name] - assert t.shape == (4 * self.dim,) - self.pwconv1_bias.copy_(t.to(self.data_type_)) - self.pwconv1_bias.load_ok = True - - if self.pwconv2_weight_name in weights: - t = weights[self.pwconv2_weight_name] - assert t.shape == (self.dim, 4 * self.dim) - self.pwconv2_weight.copy_(t.to(self.data_type_)) - self.pwconv2_weight.load_ok = True - - if self.pwconv2_bias_name in weights: - t = weights[self.pwconv2_bias_name] - assert t.shape == (self.dim,) - self.pwconv2_bias.copy_(t.to(self.data_type_)) - self.pwconv2_bias.load_ok = True - - if self.gamma_name in weights: - t = weights[self.gamma_name] - assert t.shape == (self.dim,) - self.gamma.copy_(t.to(self.data_type_)) - self.gamma.load_ok = True - - def verify_load(self): - return ( - self.norm_weight.load_ok - and self.norm_bias.load_ok - and self.pwconv1_weight.load_ok - and self.pwconv1_bias.load_ok - and self.pwconv2_weight.load_ok - and self.pwconv2_bias.load_ok - and self.gamma.load_ok - ) - - def _native_forward( - self, hidden_states: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty - ) -> torch.Tensor: - input = hidden_states - - hidden_states = self.dwconv(hidden_states) # [B, C, L] - hidden_states = hidden_states.permute(0, 2, 1) # [B, L, C] - - mean = hidden_states.mean(dim=-1, keepdim=True) - var = (hidden_states - mean).pow(2).mean(dim=-1, keepdim=True) - hidden_states = (hidden_states - mean) / torch.sqrt(var + self.eps) - hidden_states = hidden_states * self.norm_weight.view(1, 1, -1) + self.norm_bias.view(1, 1, -1) - - hidden_states = torch.nn.functional.linear(hidden_states, 
self.pwconv1_weight, self.pwconv1_bias) - hidden_states = torch.nn.functional.gelu(hidden_states) - hidden_states = torch.nn.functional.linear(hidden_states, self.pwconv2_weight, self.pwconv2_bias) - - hidden_states = hidden_states * self.gamma.view(1, 1, -1) - - hidden_states = hidden_states.permute(0, 2, 1) # [B, C, L] - hidden_states = input + hidden_states - - if out is not None: - out.copy_(hidden_states) - return out - return hidden_states - - def _cuda_forward( - self, hidden_states: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - if out is None: - result = self._native_forward(hidden_states=hidden_states, out=None, _alloc_func=alloc_func) - return result - result = self._native_forward(hidden_states=hidden_states, out=None, _alloc_func=alloc_func) - out.copy_(result) - return out - - def _musa_forward( - self, hidden_states: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - return self._cuda_forward(hidden_states=hidden_states, out=out, alloc_func=alloc_func) - - def __call__( - self, hidden_states: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - return self._forward(hidden_states=hidden_states, out=out, alloc_func=alloc_func) diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/talker_resize_mlp_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/talker_resize_mlp_weight.py deleted file mode 100644 index fac9c5cdc7..0000000000 --- a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/meta_weights/talker_resize_mlp_weight.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch -import numpy as np -from typing import Dict, Optional -from transformers.activations import ACT2FN -from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl -from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp -from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel -from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp - - -class Qwen3OmniMoeTalkerResizeMLPWeight(BaseWeightTpl, PlatformAwareOp): - def __init__( - self, - in_dim: int, - intermediate_dim: int, - out_dim: int, - fc1_weight_name: str, - fc1_bias_name: str, - fc2_weight_name: str, - fc2_bias_name: str, - hidden_act: str, - data_type: torch.dtype, - ): - super().__init__() - self.in_dim = in_dim - self.intermediate_dim = intermediate_dim - self.out_dim = out_dim - self.fc1_weight_name: str = fc1_weight_name - self.fc1_bias_name: str = fc1_bias_name - self.fc2_weight_name: str = fc2_weight_name - self.fc2_bias_name: str = fc2_bias_name - self.data_type_ = data_type - self.act_fn = ACT2FN[hidden_act] - self._create_weight() - - def _create_weight(self): - self.fc1_weight: torch.Tensor = torch.empty( - self.intermediate_dim, self.in_dim, dtype=self.data_type_, device=self.device_id_ - ) - self.fc1_bias: torch.Tensor = torch.empty(self.intermediate_dim, dtype=self.data_type_, device=self.device_id_) - self.fc2_weight: torch.Tensor = torch.empty( - self.out_dim, self.intermediate_dim, dtype=self.data_type_, device=self.device_id_ - ) - self.fc2_bias: torch.Tensor = torch.empty(self.out_dim, dtype=self.data_type_, device=self.device_id_) - self.fc1_weight.load_ok = False - self.fc1_bias.load_ok = False - self.fc2_weight.load_ok = False - self.fc2_bias.load_ok = False - - def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - 
if self.fc1_weight_name in weights: - t = weights[self.fc1_weight_name] - assert t.shape == (self.intermediate_dim, self.in_dim) - self.fc1_weight.copy_(t.to(self.data_type_)) - self.fc1_weight.load_ok = True - if self.fc1_bias_name in weights: - t = weights[self.fc1_bias_name] - assert t.shape == (self.intermediate_dim,) - self.fc1_bias.copy_(t.to(self.data_type_)) - self.fc1_bias.load_ok = True - if self.fc2_weight_name in weights: - t = weights[self.fc2_weight_name] - assert t.shape == (self.out_dim, self.intermediate_dim) - self.fc2_weight.copy_(t.to(self.data_type_)) - self.fc2_weight.load_ok = True - if self.fc2_bias_name in weights: - t = weights[self.fc2_bias_name] - assert t.shape == (self.out_dim,) - self.fc2_bias.copy_(t.to(self.data_type_)) - self.fc2_bias.load_ok = True - - def verify_load(self): - return self.fc1_weight.load_ok and self.fc1_bias.load_ok and self.fc2_weight.load_ok and self.fc2_bias.load_ok - - def _native_forward( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty - ) -> torch.Tensor: - in_dim = hidden_state.shape[-1] - assert in_dim == self.in_dim - x = hidden_state.reshape(-1, in_dim) - y = torch.nn.functional.linear(x, self.fc1_weight, self.fc1_bias) - y = self.act_fn(y) - y = torch.nn.functional.linear(y, self.fc2_weight, self.fc2_bias) - y = y.reshape(*hidden_state.shape[:-1], self.out_dim) - if out is not None: - out.copy_(y) - return out - return y - - def _cuda_forward( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - if out is None: - out = alloc_func( - (*hidden_state.shape[:-1], self.out_dim), dtype=hidden_state.dtype, device=hidden_state.device - ) - result = self._native_forward(hidden_state=hidden_state, out=None, _alloc_func=alloc_func) - out.copy_(result) - return out - - def _musa_forward( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - return self._cuda_forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func) - - def __call__( - self, hidden_state: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - return self._forward(hidden_state=hidden_state, out=out, alloc_func=alloc_func)
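
Illustration for PATCH 17/20: the _get_qkv change replaces the two separate q_proj / kv_proj matmuls with one fused qkv_proj matmul followed by a split. The sketch below shows only that split in isolation; the head counts, head_dim, embed_dim and qkv_weight are placeholder values standing in for layer_weight.qkv_proj, not values from any real Qwen3-VL-MoE config.

    import torch

    head_dim = 128
    tp_q_head_num, tp_k_head_num, tp_v_head_num = 8, 2, 2   # placeholder per-rank head counts
    embed_dim = 1024
    num_tokens = 4

    # One fused weight [embed_dim, (q + k + v) * head_dim] standing in for layer_weight.qkv_proj.
    qkv_weight = torch.randn(embed_dim, (tp_q_head_num + tp_k_head_num + tp_v_head_num) * head_dim)
    hidden = torch.randn(num_tokens, embed_dim)

    qkv = hidden @ qkv_weight  # single matmul instead of q_proj.mm + kv_proj.mm
    q, cache_kv = qkv.split(
        [tp_q_head_num * head_dim, (tp_k_head_num + tp_v_head_num) * head_dim], dim=-1
    )
    assert q.shape == (num_tokens, tp_q_head_num * head_dim)
    assert cache_kv.shape == (num_tokens, (tp_k_head_num + tp_v_head_num) * head_dim)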
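
Illustration for PATCH 18/20: get_audio_token_length first caps the raw sample count at n_samples (the 30-second limit the inline comment refers to; longer audio is truncated later in the pipeline) and then applies _caclu_audio_token_num. The sketch below reproduces that arithmetic standalone with a worked example. The 16 kHz sampling rate, hop_length of 160 and 30-second window match the hard-coded defaults removed in PATCH 14, but at runtime these come from processor.feature_extractor, so treat the constants here as assumptions.

    SAMPLING_RATE = 16000           # assumed feature-extractor sampling rate
    HOP_LENGTH = 160                # assumed mel hop length (10 ms per frame)
    N_SAMPLES = 30 * SAMPLING_RATE  # 30-second cap; longer clips are truncated

    def audio_token_num(num_samples: int) -> int:
        num_samples = min(num_samples, N_SAMPLES)
        mel_len = num_samples // HOP_LENGTH        # mel frames; 100 frames per second
        full_chunks, leave = divmod(mel_len, 100)  # full 1-second chunks + remainder frames
        feat = (leave - 1) // 2 + 1                # stride-2 reduction of the remainder
        tail = ((feat - 1) // 2 + 1 - 1) // 2 + 1  # two further stride-2 stages
        return tail + full_chunks * 13             # each full second contributes 13 tokens

    # A 5.5 s clip: 88_000 samples -> 550 mel frames -> 5 * 13 + 7 = 72 audio tokens.
    assert audio_token_num(int(5.5 * SAMPLING_RATE)) == 72
    # A 30 s (or longer, truncated) clip yields 30 * 13 = 390 tokens.
    assert audio_token_num(60 * SAMPLING_RATE) == 390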