diff --git a/lightllm/common/basemodel/layer_weights/base_layer_weight.py b/lightllm/common/basemodel/layer_weights/base_layer_weight.py index 1875e2c3b..b1d992a7c 100644 --- a/lightllm/common/basemodel/layer_weights/base_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/base_layer_weight.py @@ -33,7 +33,11 @@ def verify_load(self): for attr_name in dir(self): attr = getattr(self, attr_name) if isinstance(attr, BaseWeight): - assert attr.verify_load(), f"Loading {attr_name} of layers {self.layer_num_} fails." + if hasattr(self, "layer_num_"): + layer_num = self.layer_num_ + else: + layer_num = None + assert attr.verify_load(), f"Loading {attr_name} of layers {layer_num} fails." def _cuda(self, cpu_tensor): return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id()) diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 095f73679..32ccbe833 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -37,4 +37,5 @@ Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel +from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py index 756198e89..eace676c0 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py +++ b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py @@ -28,13 +28,13 @@ def _get_mrope_position_triton( local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) image_start_idx = start_loc + local_image_start_idx - cache_len image_len = tl.load(b_image_len + image_start_num + i) - image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) + # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) for j in range(0, image_len, BLOCK_SIZE): off = j + tl.arange(0, BLOCK_SIZE) # 目前没考虑视频,所以t 恒为 0 t_pos = local_image_start_idx + off * 0 - h_pos = local_image_start_idx + off // image_h + h_pos = local_image_start_idx + off // image_w w_pos = local_image_start_idx + off % image_w tl.store( position_ids + off + image_start_idx, diff --git a/lightllm/models/qwen3_omni_moe_thinker/__init__.py b/lightllm/models/qwen3_omni_moe_thinker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py new file mode 100644 index 000000000..833cc8f4b --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py @@ -0,0 +1,185 @@ +import torch +import numpy as np +from typing import TYPE_CHECKING, Any, Optional, Union, Tuple +from transformers.audio_utils import mel_filter_bank, spectrogram, window_function +from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor +from transformers.feature_extraction_utils import BatchFeature +from transformers.utils import TensorType + + +class WhisperFeatureExtractor(SequenceFeatureExtractor): + + model_input_names = ["input_features"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + hop_length=160, + chunk_length=30, + n_fft=400, + padding_value=0.0, + dither=0.0, + return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask + **kwargs, + ): 
+ super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + self.n_fft = n_fft + self.hop_length = hop_length + self.chunk_length = chunk_length + self.n_samples = chunk_length * sampling_rate + self.nb_max_frames = self.n_samples // hop_length + self.sampling_rate = sampling_rate + self.dither = dither + self.mel_filters = mel_filter_bank( + num_frequency_bins=1 + n_fft // 2, + num_mel_filters=feature_size, + min_frequency=0.0, + max_frequency=8000.0, + sampling_rate=sampling_rate, + norm="slaney", + mel_scale="slaney", + ) + + def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray: + waveform = torch.from_numpy(waveform).to(device, torch.float32) + window = torch.hann_window(self.n_fft, device=device) + + if self.dither != 0.0: + waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device) + + stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + mel_spec = mel_filters.T @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + if waveform.dim() == 2: + max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] + log_spec = torch.maximum(log_spec, max_val - 8.0) + else: + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if device != "cpu": + log_spec = log_spec.detach().cpu() + return log_spec.numpy() + + # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2. + # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm + def zero_mean_unit_var_norm( + self, input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0 + ) -> list[np.ndarray]: + if attention_mask is not None: + attention_mask = np.array(attention_mask, np.int32) + normed_input_values = [] + + for vector, length in zip(input_values, attention_mask.sum(-1)): + normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + if length < normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + else: + normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + + return normed_input_values + + def _preprocess( + self, + raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + truncation: bool = True, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: Optional[bool] = None, + padding: Optional[str] = "longest", # max_length代表padding到max_length + max_length: Optional[int] = None, + sampling_rate: Optional[int] = 16000, + do_normalize: Optional[bool] = None, + device: Optional[str] = "cpu", + return_token_timestamps: Optional[bool] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech] + elif not 
is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [np.asarray([raw_speech]).T] + + batched_speech = BatchFeature({"input_features": raw_speech}) + + # convert into correct format for padding + + padded_inputs = self.pad( + batched_speech, + padding=padding, + max_length=max_length if max_length else self.n_samples, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask or do_normalize, + ) + + # zero-mean and unit-variance normalization + if do_normalize: + padded_inputs["input_features"] = self.zero_mean_unit_var_norm( + padded_inputs["input_features"], + attention_mask=padded_inputs["attention_mask"], + padding_value=self.padding_value, + ) + padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0) + + # make sure list is in array format + input_features = padded_inputs.get("input_features").transpose(2, 0, 1) + + input_features = self._torch_extract_fbank_features(input_features[0], device) + + if isinstance(input_features[0], list): + padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] + + else: + padded_inputs["input_features"] = input_features + + if return_attention_mask: + # rescale from sample (48000) to feature (3000) + rescaled_attention_mask = padded_inputs["attention_mask"][:, :: self.hop_length] + + # The STFT computation produces L//hop_length + 1 frames, + # but we skip the last frame (see `_torch_extract_fbank_features`). + # This means we need to trim the rescaled attention mask to match + # the actual number of frames (L//hop_length) when the input length + # is not perfectly divisible by the hop length. 
+ if padded_inputs["attention_mask"].shape[1] % self.hop_length != 0: + rescaled_attention_mask = rescaled_attention_mask[:, :-1] + padded_inputs["attention_mask"] = rescaled_attention_mask + + if return_token_timestamps is not None: + padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech] + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + input_features = torch.from_numpy(np.asarray(padded_inputs["input_features"], dtype=np.float32)).to( + device="cuda", dtype=torch.bfloat16 + ) + attention_mask = torch.from_numpy(np.asarray(padded_inputs["attention_mask"], dtype=np.float32)).to( + device="cuda", dtype=torch.int32 + ) + return input_features, attention_mask diff --git a/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py new file mode 100644 index 000000000..1c09ebf44 --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/infer_struct.py @@ -0,0 +1,6 @@ +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo + + +class Qwen3OmniMOEInferStateInfo(Qwen3VLInferStateInfo): + def __init__(self): + super().__init__() diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..1a05a752f --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py @@ -0,0 +1,15 @@ +import torch +from lightllm.models.qwen3_vl_moe.layer_infer.transformer_layer_infer import Qwen3VLMOETransformerLayerInfer +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen3OmniMOETransformerLayerInfer(Qwen3VLMOETransformerLayerInfer): + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.head_dim_ = network_config["head_dim"] + self.mrope_section = torch.tensor( + network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" + ) + return diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..5ac8060c4 --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,30 @@ +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight + + +class Qwen3OmniMOEThinkerPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="thinker.model.embed_tokens.weight", + data_type=self.data_type_, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="thinker.lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name="thinker.model.norm.weight", + data_type=self.data_type_, + ) + return diff 
--git a/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py new file mode 100644 index 000000000..775ba5ffe --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_weights/transformers_layer_weight.py @@ -0,0 +1,53 @@ +import os +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight + + +class Qwen3OmniMOEThinkerTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_weight_names(self): + self._q_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_proj.weight" + self._q_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_norm.weight" + self._q_bias_name = None + self._k_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_proj.weight" + self._k_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_norm.weight" + self._k_bias_name = None + self._v_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.v_proj.weight" + self._v_bias_name = None + self._kv_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.kv_proj.weight" + self._kv_bias_name = None + self._o_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.o_proj.weight" + self._o_bias_name = None + self._att_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.input_layernorm.weight" + self._att_norm_bias_name = None + self._ffn_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.post_attention_layernorm.weight" + self._ffn_norm_bias_name = None + + def _init_moe(self): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"thinker.model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"thinker.model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py new file mode 100644 index 000000000..2e863da00 --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/model.py @@ -0,0 +1,158 @@ +import os +import json +import librosa +from io import BytesIO +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer + +from lightllm.models.qwen3_omni_moe_thinker.layer_infer.transformer_layer_infer import Qwen3OmniMOETransformerLayerInfer +from 
lightllm.models.qwen3_omni_moe_thinker.layer_weights.pre_and_post_layer_weight import (
+    Qwen3OmniMOEThinkerPreAndPostLayerWeight,
+)
+from lightllm.models.qwen3_omni_moe_thinker.layer_weights.transformers_layer_weight import (
+    Qwen3OmniMOEThinkerTransformerLayerWeight,
+)
+
+from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel
+from lightllm.models.qwen3_omni_moe_thinker.infer_struct import Qwen3OmniMOEInferStateInfo
+from lightllm.models.qwen3_vl.model import QWen3VLTokenizer
+from lightllm.server.core.objs import SamplingParams
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+
+
+MIN_AUDIO_LEN = 480
+
+
+class QWen3OmniTokenizer(QWen3VLTokenizer):
+    def __init__(self, tokenizer=None, processor=None, **kwargs):
+        self.tokenizer = tokenizer
+        # image
+        self.image_processor = processor.image_processor
+        self.min_pixel = self.image_processor.min_pixels
+        self.max_pixel = self.image_processor.max_pixels
+        self.patch_size = self.image_processor.patch_size
+        self.merge_size = self.image_processor.merge_size
+
+        # audio
+        self.audio_processor = processor.feature_extractor
+        self.sampling_rate = self.audio_processor.sampling_rate
+        self.n_samples = self.audio_processor.n_samples
+        self.hop_length = self.audio_processor.hop_length
+
+        self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"]
+        self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"]
+        self.image_token_id = kwargs["model_cfg"]["image_token_id"]
+
+        self.audio_start_id = kwargs["model_cfg"]["audio_start_token_id"]
+        self.audio_end_id = kwargs["model_cfg"]["audio_end_token_id"]
+        self.audio_token_id = kwargs["model_cfg"]["audio_token_id"]
+
+    def init_audioitem_extral_params(
+        self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        return
+
+    def get_audio_token_length(self, audio: AudioItem):
+        # The audio length is capped at 30 seconds worth of samples here; anything longer is truncated downstream.
+        length = min(audio.audio_length, int(self.n_samples))
+        token_num = self._caclu_audio_token_num(length)
+        return token_num
+
+    def _caclu_audio_token_num(self, input_audio_len: int):
+        _mel_len = input_audio_len // int(self.hop_length)
+        input_lengths_leave = _mel_len % 100
+        feat_lengths = (input_lengths_leave - 1) // 2 + 1
+        output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (_mel_len // 100) * 13
+        return output_lengths
+
+    def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+        origin_ids = self.tokenizer.encode(prompt)
+
+        # drop the image/audio pad placeholder tokens, keeping only the start/end marker tokens
+        origin_ids = [token for token in origin_ids if token not in (self.image_token_id, self.audio_token_id)]
+        # then expand each start/end marker pair into start, id, id+1, ..., id+num-1, end
+        input_ids = []
+        image_id = 0
+        while True:
+            try:
+                start_idx = origin_ids.index(self.image_start_id)
+                if start_idx + 1 >= len(origin_ids):
+                    break
+                if origin_ids[start_idx + 1] == self.image_end_id:
+                    input_ids.extend(origin_ids[: start_idx + 1])
+                    token_id = multimodal_params.images[image_id].token_id
+                    token_num = multimodal_params.images[image_id].token_num
+                    multimodal_params.images[image_id].start_idx = len(input_ids)
+                    input_ids.extend(range(token_id, token_id + token_num))
+                    input_ids.append(self.image_end_id)
+                    origin_ids = origin_ids[start_idx + 2 :]
+                    image_id += 1
+                else:
+                    raise ValueError("image token error")
+            except ValueError:
+                break
+        if multimodal_params:
+            image_cnt = len(multimodal_params.images)
+            if image_cnt != image_id:
+                raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!")
+        input_ids.extend(origin_ids)
+
+        # audio
+        origin_ids = input_ids
+        input_ids = []
+        audio_id = 0
+        start_idx = 0
+        while True:
+            try:
+                start_idx = origin_ids.index(self.audio_start_id)
+                if start_idx + 1 >= len(origin_ids):
+                    break
+                if origin_ids[start_idx + 1] == self.audio_end_id:
+                    input_ids.extend(origin_ids[: start_idx + 1])
+                    token_id = multimodal_params.audios[audio_id].token_id
+                    token_num = multimodal_params.audios[audio_id].token_num
+                    input_ids.extend(range(token_id, token_id + token_num))
+                    input_ids.append(self.audio_end_id)
+                    origin_ids = origin_ids[start_idx + 2 :]
+                    audio_id += 1
+                else:
+                    raise ValueError("audio token error")
+            except ValueError:
+                break
+        if multimodal_params:
+            audio_cnt = len(multimodal_params.audios)
+            if audio_cnt != audio_id:
+                raise ValueError(f"invalid audio tag num: {audio_cnt} vs {audio_id}!")
+        input_ids.extend(origin_ids)
+
+        return input_ids
+
+
+@ModelRegistry(["qwen3_omni_moe"], is_multimodal=True)
+class Qwen3OmniMOETpPartModel(Qwen3VLMOETpPartModel):
+
+    pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer
+    transformer_layer_infer_class = Qwen3OmniMOETransformerLayerInfer
+
+    pre_and_post_weight_class = Qwen3OmniMOEThinkerPreAndPostLayerWeight
+    transformer_weight_class = Qwen3OmniMOEThinkerTransformerLayerWeight
+
+    infer_state_class = Qwen3OmniMOEInferStateInfo
+
+    def __init__(self, kvargs):
+        super().__init__(kvargs)
+        return
+
+    def _init_config(self):
+        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+            all_config = json.load(json_file)
+            self.config = all_config["thinker_config"]["text_config"]
+        # rename keys
+        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+        if self.finetune_config:
+            self.config["vocab_size"] = self.finetune_config.vocab_size
+        return
diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
new file mode 100644
index 000000000..c66033f53
--- /dev/null
+++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
@@ -0,0 +1,388 @@
+import os
+import json
+import math
+import torch
+import rpyc
+import librosa
+import numpy as np
+from io import BytesIO
+from torch import Tensor, nn
+from safetensors import safe_open
+from torch.nn import functional as F
+from typing import Callable, Optional, Union, List
+from rpyc.utils.classic import obtain
+
+from transformers.activations import ACT2FN
+
+from lightllm.server.multimodal_params import AudioItem
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from lightllm.models.qwen3_omni_moe_thinker.audio_process import WhisperFeatureExtractor
+
+
+def _get_feat_extract_output_lengths(input_lengths):
+    """
+    Computes the output length of the convolutional layers and the output length of the audio encoder
+    """
+
+    input_lengths_leave = input_lengths % 100
+    feat_lengths = (input_lengths_leave - 1) // 2 + 1
+    output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
+    return output_lengths
+
+
+class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        encoder_attention_heads,
+ attention_dropout, + dropout, + activation_function, + activation_dropout, + encoder_ffn_dim, + ): + super().__init__() + self.embed_dim = d_model + self.self_attn = Qwen3OmniMoeAudioAttention(d_model, encoder_attention_heads, attention_dropout) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = dropout + self.activation_fn = ACT2FN[activation_function] + self.activation_dropout = activation_dropout + self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim) + self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + **kwargs, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return outputs + + +class Qwen3OmniMoeAudioAttention(nn.Module): + def __init__(self, d_model, encoder_attention_heads, attention_dropout): + super().__init__() + self.embed_dim = d_model + self.num_heads = encoder_attention_heads + self.dropout = attention_dropout + self.head_dim = self.embed_dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.scaling = self.head_dim ** -0.5 + self.attention_dropout = 0.0 + self.is_decoder = False + self.is_causal = False + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: int = 0, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + q = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) + + flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.out_proj(attn_output) + return attn_output + + +class SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels input") + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + self.positional_embedding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +class Qwen3OmniMoeAudioEncoder(nn.Module): + def __init__( + self, + kvargs, + dropout=0, + d_model=1280, + num_mel_bins=128, + max_source_positions=1500, + scale_embedding=False, + n_window=50, + encoder_layers=32, + downsample_hidden_size=480, + activation_function="gelu", + output_dim=2048, + n_window_infer=800, + conv_chunksize=500, + encoder_attention_heads=20, + attention_dropout=0, + activation_dropout=0, + encoder_ffn_dim=5120, + ): + super().__init__() + self.data_type = kvargs.get("data_type", "bfloat16") + self.dropout = dropout + self.embed_dim = d_model + self.num_mel_bins = num_mel_bins + self.max_source_positions = max_source_positions + self.embed_scale = math.sqrt(self.embed_dim) if scale_embedding else 1.0 + self.n_window = n_window + self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, self.embed_dim) + self.layers = nn.ModuleList( + [ + Qwen3OmniMoeAudioEncoderLayer( + d_model, + encoder_attention_heads, + attention_dropout, + dropout, + activation_function, + activation_dropout, + encoder_ffn_dim, + ) + for _ in range(encoder_layers) + ] + ) + self.ln_post = nn.LayerNorm(d_model) + self.gradient_checkpointing = False + self.conv2d1 = nn.Conv2d(1, downsample_hidden_size, 3, 2, padding=1) + self.conv2d2 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1) + self.conv2d3 = nn.Conv2d(downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1) + self.conv_out = nn.Linear( + downsample_hidden_size * ((((num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2), + d_model, + bias=False, + ) + self.proj1 = nn.Linear(d_model, d_model) + self.act = ACT2FN[activation_function] + self.proj2 = 
nn.Linear(d_model, output_dim) + self.n_window_infer = n_window_infer + self.conv_chunksize = conv_chunksize + + self.cache_port = kvargs["cache_port"] + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + def load_model(self, weight_dir, config): + processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + self.processor = WhisperFeatureExtractor(**processor_config_dict) + + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "thinker.audio_tower" in k: + weight_dict[k[len("thinker.audio_tower.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "thinker.audio_tower" in k: + weight_dict[k[len("thinker.audio_tower.") :]] = f.get_tensor(k) + + self.load_state_dict(weight_dict) + + def forward( + self, + input_features, + feature_lens=None, + aftercnn_lens=None, + ): + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths[chunk_lengths == 0] = self.n_window * 2 + + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + padded_mask_after_cnn = nn.utils.rnn.pad_sequence( + [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], + batch_first=True, + ) + padded_feature = padded_feature.unsqueeze(1) + # Split to chunk to avoid OOM during convolution + padded_embeds = [] + for chunk in padded_feature.split(self.conv_chunksize, dim=0): + padded_embed = F.gelu(self.conv2d1(chunk)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = 
F.gelu(self.conv2d3(padded_embed)) + padded_embeds.append(padded_embed) + padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() + padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) + + positional_embedding = ( + self.positional_embedding.positional_embedding[: padded_embed.shape[1], :] + .unsqueeze(0) + .to(padded_embed.dtype) + .to(padded_embed.device) + ) + padded_embed = padded_embed + positional_embedding + hidden_states = padded_embed[padded_mask_after_cnn] + cu_chunk_lens = [0] + window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + max_seqlen, + ) + hidden_states = layer_outputs[0] + + hidden_states = self.ln_post(hidden_states) + hidden_states = self.proj1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.proj2(hidden_states) + return hidden_states + + def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient): + uuids = [] + items: List[AudioItem] = [] + per_audio_features: List[torch.Tensor] = [] + for i, item in enumerate(audio_items): + if isinstance(item, AudioItem): + uuids.append(item.uuid) + items.append(item) + audio_data = read_shm(get_shm_name_data(item.uuid)) + audio = BytesIO(audio_data) + audio, _ = librosa.load(audio, sr=self.processor.sampling_rate) + else: + raise ValueError(f"cannot read audio which type is {type(item)}!") + + input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True) + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = None + + feature_lens = ( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + + audio_features = self.forward( + input_features, + feature_lens=feature_lens, + ) + per_audio_features.append(audio_features) + + ready_audio = obtain(self.cache_client.root.get_items_embed(uuids)) + ids_to_set = [] + for i, ready in enumerate(ready_audio): + if ready: + continue + + uid = uuids[i] + item = items[i] + + cur_embed = per_audio_features[i] + cpu_embed_cache_client.copy_to_cache( + embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache + ) + ids_to_set.append(uid) + + if ids_to_set: + self.cache_client.root.set_items_embed(ids=ids_to_set) + torch.cuda.current_stream().synchronize() diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py new file mode 100644 index 000000000..dd9b54ee8 --- /dev/null +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -0,0 +1,408 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import time +from PIL import Image +from io import BytesIO +from typing import List +from safetensors import safe_open +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.activations import ACT2FN + +from lightllm.server.multimodal_params import ImageItem +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention + + +class Qwen3OmniMoeVisionMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3OmniMoeVisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + if hidden_states.dtype != target_dtype: + hidden_states = hidden_states.to(target_dtype) + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3OmniMoeVisionPatchMerger(nn.Module): + def __init__(self, hidden_size, out_hidden_size, spatial_merge_size, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = hidden_size * (spatial_merge_size ** 2) + self.use_postshuffle_norm = use_postshuffle_norm + self.ln_q = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else hidden_size, eps=1e-6) + self.mlp = nn.ModuleList( + [ + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, out_hidden_size), + ] + ) + + def forward(self, hidden: torch.Tensor) -> torch.Tensor: + hidden = self.ln_q(hidden.view(-1, self.hidden_size) if self.use_postshuffle_norm else hidden).view( + -1, self.hidden_size + ) + for layer in self.mlp: + hidden = layer(hidden) + return hidden + + +class Qwen3OmniMoeVisionBlock(nn.Module): + def __init__(self, hidden_size, intermediate_size, num_heads, hidden_act) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6) + + self.attn = 
VisionFlashAttention(hidden_size, num_heads=num_heads) + self.mlp = Qwen3OmniMoeVisionMLP(hidden_size, intermediate_size, hidden_act) + + def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3OmniMoeVisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + depth=27, + out_hidden_size=4096, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + deepstack_visual_indexes=[8, 16, 24], + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + num_position_embeddings=2304, + **kwargs, + ): + super().__init__() + self.data_type = kvargs.get("data_type", "bfloat16") + + self.depth = depth + self.out_hidden_size = out_hidden_size + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.intermediate_size = intermediate_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.num_position_embeddings = num_position_embeddings + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.deepstack_visual_indexes = deepstack_visual_indexes + self.merger_list = nn.ModuleList( + [ + Qwen3OmniMoeVisionPatchMerger( + hidden_size=self.hidden_size, + out_hidden_size=self.out_hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + ) + for _ in range(len(self.deepstack_visual_indexes)) + ] + ) + + self.patch_embed = Qwen3OmniMoeVisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=self.in_channels, + embed_dim=self.hidden_size, + ) + + self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) + self.num_grid_per_side = int(self.num_position_embeddings ** 0.5) + + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).cuda() + + self.blocks = nn.ModuleList( + [ + Qwen3OmniMoeVisionBlock(self.hidden_size, self.intermediate_size, self.num_heads, self.hidden_act) + for _ in range(self.depth) + ] + ) + self.merger = Qwen3OmniMoeVisionPatchMerger( + hidden_size=self.hidden_size, + out_hidden_size=self.out_hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=False, + ) + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids): + all_chunks = [] + + for start, end in valid_ids: + hs_i = image_embed[start:end] + ds_i_list = [feat[start:end] for feat in deepstack_feature_lists] + combined_i = torch.cat([hs_i, *ds_i_list], dim=1).view((end - start), len(ds_i_list) + 1, hs_i.shape[-1]) + all_chunks.append(combined_i) + + all_img_embeds_ds = torch.cat(all_chunks, dim=0) + return 
all_img_embeds_ds, valid_ids + + def load_model(self, weight_dir): + + processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + self.processor = Qwen2VLImageProcessor(**processor_config_dict) + + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "thinker.visual" in k: + weight_dict[k[len("thinker.visual.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "thinker.visual" in k: + weight_dict[k[len("thinker.visual.") :]] = f.get_tensor(k) + + self.load_state_dict(weight_dict) + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h) # block row indices + block_cols = torch.arange(merged_w) # block col indices + intra_row = torch.arange(merge_size) # intra-block row offsets + intra_col = torch.arange(merge_size) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + max_hw = int(grid_thw[:, 1:].max().item()) + cos_full, sin_full = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + cos = cos_full[pos_ids].flatten(1) + sin = sin_full[pos_ids].flatten(1) + return cos, sin + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + 
(dh[None].T * dw[None]).flatten(),
+            ]
+
+            for i in range(4):
+                idx_list[i].extend(indices[i].tolist())
+                weight_list[i].extend(weights[i].tolist())
+
+        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+        weight_tensor = torch.tensor(
+            weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+        )
+        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+        patch_pos_embeds_permute = []
+        merge_size = self.spatial_merge_size
+        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+            pos_embed = pos_embed.repeat(t, 1)
+            pos_embed = (
+                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+                .permute(0, 1, 3, 2, 4, 5)
+                .flatten(0, 4)
+            )
+            patch_pos_embeds_permute.append(pos_embed)
+        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+        return patch_pos_embeds
+
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+        hidden_states = self.patch_embed(hidden_states)
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        hidden_states = hidden_states + pos_embeds
+        rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw)
+        rotary_cos = rotary_cos.to("cuda", non_blocking=True)
+        rotary_sin = rotary_sin.to("cuda", non_blocking=True)
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0,
+            dtype=torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        deepstack_feature_lists = []
+        for layer_num, blk in enumerate(self.blocks):
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+                rotary_cos=rotary_cos,
+                rotary_sin=rotary_sin,
+            )
+            if layer_num in self.deepstack_visual_indexes:
+                deepstack_feature = self.merger_list[self.deepstack_visual_indexes.index(layer_num)](hidden_states)
+                deepstack_feature_lists.append(deepstack_feature)
+
+        hidden_states = self.merger(hidden_states)
+        return hidden_states, deepstack_feature_lists
+
+    def encode(self, images: List[ImageItem]):
+        img_tensors = []
+        valid_ids = []
+        valid_id = 0
+        img_grids = []
+        uuids = []
+        for i, img in enumerate(images):
+            if isinstance(img, ImageItem):
+                uuids.append(img.uuid)
+                image_data = read_shm(get_shm_name_data(img.uuid))
+                image_data = Image.open(BytesIO(image_data))
+                pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+                img_tensors.append(pixel_values)
+                img_grids.append(image_grid_thw)
+            else:
+                raise Exception("Unsupported input type: {} for {}".format(type(img), img))
+
+            # the per-image patch count must be divisible by merge_length (spatial_merge_size ** 2)
+            cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2)
+
+            valid_ids.append([valid_id, valid_id + cur_num])
+            valid_id += cur_num
+
+        if len(img_tensors) <= 0:
+            return None
+
+        imgs = torch.cat(img_tensors, dim=0)
+        grid_thw = torch.cat(img_grids, dim=0)
+
+        pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+        image_grid_thw = grid_thw.to("cuda", non_blocking=True)
+        img_embeds, deepstack_feature_lists = self.forward(pixel_values, grid_thw=image_grid_thw)
+        all_img_embeds_df, valid_ids = self.concat_img_embed_and_deepstack_features(
+            img_embeds,
deepstack_feature_lists, valid_ids + ) + + return all_img_embeds_df, uuids, valid_ids diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index 4ccc6da37..0bf716fc5 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -23,8 +23,10 @@ def _get_qkv( layer_weight: Qwen3MOETransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input) + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) layer_weight.q_norm_weight_( q, eps=self.eps_, diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py index f3c813800..ad18ffd7f 100644 --- a/lightllm/server/audioserver/model_infer/model_rpc.py +++ b/lightllm/server/audioserver/model_infer/model_rpc.py @@ -4,6 +4,7 @@ from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig from lightllm.models.whisper.whisper_audio import WhisperAudioModel +from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_audio import Qwen3OmniMoeAudioEncoder from lightllm.server.multimodal_params import AudioItem from lightllm.utils.infer_utils import set_random_seed from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient @@ -19,6 +20,9 @@ def exposed_init_model(self, kvargs): weight_dir = kvargs["weight_dir"] model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) + if model_cfg.get("thinker_config") is not None: + model_cfg = model_cfg["thinker_config"] + audio_config = model_cfg["audio_config"] model_kvargs = {"cache_port": kvargs["cache_port"], "data_type": kvargs["data_type"]} @@ -26,6 +30,8 @@ def exposed_init_model(self, kvargs): self.model_type = audio_config["model_type"] if self.model_type == "clap_audio_model" or self.model_type == "whisper": self.model = WhisperAudioModel(model_kvargs) + elif self.model_type == "qwen3_omni_moe_audio_encoder": + self.model = Qwen3OmniMoeAudioEncoder(model_kvargs).eval().bfloat16() else: raise Exception(f"can not support {self.model_type} now") diff --git a/lightllm/server/embed_cache/embed_cache_client.py b/lightllm/server/embed_cache/embed_cache_client.py index ad01dd808..8c5b7f71e 100644 --- a/lightllm/server/embed_cache/embed_cache_client.py +++ b/lightllm/server/embed_cache/embed_cache_client.py @@ -47,6 +47,19 @@ def copy_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int): start_index_in_cache=start_index_in_cache, ) + def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int): + from .copy_to_cache import offload_embed_tensor_to_cache + + if embed_tensor.ndim == 3: + # check for qwen3 vision embed tensor shape, use apply deepstack + assert embed_tensor.shape[1] == self.cpu_embed_cache_tensor.shape[1] + + offload_embed_tensor_to_cache( + embed_tensor=embed_tensor, + cache_tensor=self.cpu_embed_cache_tensor, + start_index_in_cache=start_index_in_cache, + ) + def _create_shm_embed_kv_cache(self): shm_ptr = create_shm_kv_cache_ptr( key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size() diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e0b2bd425..09bc938f2 
100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -31,6 +31,7 @@ from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer from ..models.gemma3.model import Gemma3Tokenizer +from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" @@ -100,6 +101,12 @@ def get_tokenizer( tokenizer = QWen3VLTokenizer( tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) + elif model_cfg.get("thinker_config") is not None: + from transformers import AutoProcessor + + model_cfg = model_cfg["thinker_config"] + processor = AutoProcessor.from_pretrained(tokenizer_name) + tokenizer = QWen3OmniTokenizer(tokenizer, processor=processor, model_cfg=model_cfg) elif model_type == "internvl_chat": tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 8f4a1ee45..3e97f4de3 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -19,6 +19,7 @@ from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel +from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry @@ -80,6 +81,15 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() + elif ( + model_cfg.get("thinker_config", {}).get("vision_config", {}).get("model_type") + == "qwen3_omni_moe_vision_encoder" + ): + self.model = ( + Qwen3OmniMoeVisionTransformerPretrainedModel(kvargs, **model_cfg["thinker_config"]["vision_config"]) + .eval() + .bfloat16() + ) else: raise Exception(f"can not support {self.model_type} now") @@ -117,7 +127,7 @@ def exposed_encode(self, images: List[ImageItem]): uid = uuids[i] start, end = valid_ids[i] image = images[i] - self.cpu_embed_cache_client.copy_to_cache( + self.cpu_embed_cache_client.copy_vision_to_cache( embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache ) ids_to_set.append(uid) diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index b06309f96..790f185f2 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -25,6 +25,8 @@ def _get_config_llm_keyvalue(model_path: str, key_name: list[str]): value = config_json["llm_config"][key] except: value = config_json.get("text_config", {}).get(key) + if config_json.get("thinker_config") is not None: + value = config_json.get("thinker_config", {}).get("text_config").get(key) if value is not None: return value @@ -77,6 +79,14 @@ def get_layer_num(model_path: str) -> int: def get_eos_token_ids(model_path: str) -> Optional[List[int]]: + try: + # qwen3-omini special eos_token_id + config_json = get_config_json(model_path) + assert config_json["architectures"][0] == 
"Qwen3OmniMoeForConditionalGeneration" + return [151645] + except: + pass + eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"]) if isinstance(eos_token_id, int): return [eos_token_id] @@ -100,6 +110,9 @@ def get_model_architectures(model_path: str): def get_vocab_size(model_path: str): try: config_json = get_config_json(model_path) + # qwen3-omini special + if "thinker_config" in config_json: + config_json = config_json["thinker_config"] if "llm_config" in config_json: vocab_size = int(config_json["llm_config"]["vocab_size"]) return vocab_size diff --git a/lightllm/utils/embed_utils.py b/lightllm/utils/embed_utils.py index 81b05dc29..dec1537d6 100644 --- a/lightllm/utils/embed_utils.py +++ b/lightllm/utils/embed_utils.py @@ -27,12 +27,12 @@ def calcu_embed_cache_meta() -> "EmbedCacheMeta": args = get_env_start_args() assert args.enable_multimodal from lightllm.utils.llm_utils import get_llm_model_class - from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel + from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel, Qwen3OmniMOETpPartModel model_class = get_llm_model_class() model_dir = args.model_dir - if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel]: + if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel, Qwen3OmniMOETpPartModel]: embed_cache_meta_data = EmbedCacheMeta( token_num=None, layer_num=4, diff --git a/requirements.txt b/requirements.txt index a3b9473f8..89c52ebf8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,8 +61,8 @@ sortedcontainers==2.4.0 toolz==0.12.0 torch==2.8.0 tqdm==4.65.0 -transformers==4.53.3 -tokenizers==0.21.1 +transformers==4.57.1 +tokenizers==0.22.1 urllib3==1.26.16 uvicorn==0.19.0 uvloop==0.17.0