6 changes: 5 additions & 1 deletion lightllm/common/basemodel/layer_weights/base_layer_weight.py
@@ -33,7 +33,11 @@ def verify_load(self):
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, BaseWeight):
assert attr.verify_load(), f"Loading {attr_name} of layers {self.layer_num_} fails."
if hasattr(self, "layer_num_"):
layer_num = self.layer_num_
else:
layer_num = None
assert attr.verify_load(), f"Loading {attr_name} of layers {layer_num} fails."

def _cuda(self, cpu_tensor):
return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id())
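As a side note, the defensive lookup added above is equivalent to a `getattr` with a default; a minimal, self-contained sketch of the same behavior (an illustration, not the PR's code):

```python
# Hypothetical illustration: weights without a layer_num_ attribute
# (e.g. pre/post layer weights) simply report None in the error message.
def _resolve_layer_num(obj):
    return getattr(obj, "layer_num_", None)

class _Dummy:
    pass

w = _Dummy()
assert _resolve_layer_num(w) is None   # no layer_num_ attribute
w.layer_num_ = 7
assert _resolve_layer_num(w) == 7      # attribute present
```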
1 change: 1 addition & 0 deletions lightllm/models/__init__.py
@@ -37,4 +37,5 @@
Tarsier2LlamaTpPartModel,
)
from lightllm.models.gpt_oss.model import GptOssTpPartModel
from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel
from .registry import get_model, get_model_class
@@ -28,13 +28,13 @@ def _get_mrope_position_triton(
local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
image_start_idx = start_loc + local_image_start_idx - cache_len
image_len = tl.load(b_image_len + image_start_num + i)
image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
# image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
for j in range(0, image_len, BLOCK_SIZE):
off = j + tl.arange(0, BLOCK_SIZE)
# Video is not handled yet, so t is always 0
t_pos = local_image_start_idx + off * 0
h_pos = local_image_start_idx + off // image_h
h_pos = local_image_start_idx + off // image_w
w_pos = local_image_start_idx + off % image_w
tl.store(
position_ids + off + image_start_idx,
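The second hunk above replaces `off // image_h` with `off // image_w` when deriving the height position from a flattened image-token offset. A minimal sketch of the intended row-major mapping, with hypothetical grid sizes (plain Python, not the Triton kernel):

```python
# For an image token grid of image_h rows by image_w columns, the k-th
# flattened token lies at row k // image_w and column k % image_w.
image_h, image_w = 3, 4                # hypothetical patch grid
for k in range(image_h * image_w):
    t_pos = 0                          # video not handled yet, so t is always 0
    h_pos = k // image_w               # row index advances once per image_w tokens
    w_pos = k % image_w                # column index cycles 0 .. image_w - 1
    assert 0 <= h_pos < image_h and 0 <= w_pos < image_w
```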
Empty file.
185 changes: 185 additions & 0 deletions lightllm/models/qwen3_omni_moe_thinker/audio_process.py
@@ -0,0 +1,185 @@
import torch
import numpy as np
from typing import TYPE_CHECKING, Any, Optional, Union, Tuple
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType


class WhisperFeatureExtractor(SequenceFeatureExtractor):

model_input_names = ["input_features"]

def __init__(
self,
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
padding_value=0.0,
dither=0.0,
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
self.dither = dither
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + n_fft // 2,
num_mel_filters=feature_size,
min_frequency=0.0,
max_frequency=8000.0,
sampling_rate=sampling_rate,
norm="slaney",
mel_scale="slaney",
)

def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray:
waveform = torch.from_numpy(waveform).to(device, torch.float32)
window = torch.hann_window(self.n_fft, device=device)

if self.dither != 0.0:
waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)

stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
mel_spec = mel_filters.T @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
if waveform.dim() == 2:
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_val - 8.0)
else:
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
if device != "cpu":
log_spec = log_spec.detach().cpu()
return log_spec.numpy()

# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.
# Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
def zero_mean_unit_var_norm(
self, input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0
) -> list[np.ndarray]:
if attention_mask is not None:
attention_mask = np.array(attention_mask, np.int32)
normed_input_values = []

for vector, length in zip(input_values, attention_mask.sum(-1)):
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
if length < normed_slice.shape[0]:
normed_slice[length:] = padding_value

normed_input_values.append(normed_slice)
else:
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]

return normed_input_values

def _preprocess(
self,
raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
truncation: bool = True,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: Optional[bool] = None,
padding: Optional[str] = "longest",  # "max_length" means pad to max_length
max_length: Optional[int] = None,
sampling_rate: Optional[int] = 16000,
do_normalize: Optional[bool] = None,
device: Optional[str] = "cpu",
return_token_timestamps: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:

is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)

if is_batched:
raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
raw_speech = raw_speech.astype(np.float32)

# always return batch
if not is_batched:
raw_speech = [np.asarray([raw_speech]).T]

batched_speech = BatchFeature({"input_features": raw_speech})

# convert into correct format for padding

padded_inputs = self.pad(
batched_speech,
padding=padding,
max_length=max_length if max_length else self.n_samples,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask or do_normalize,
)

# zero-mean and unit-variance normalization
if do_normalize:
padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
padded_inputs["input_features"],
attention_mask=padded_inputs["attention_mask"],
padding_value=self.padding_value,
)
padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)

# make sure list is in array format
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)

input_features = self._torch_extract_fbank_features(input_features[0], device)

if isinstance(input_features[0], list):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]

else:
padded_inputs["input_features"] = input_features

if return_attention_mask:
# rescale from sample (48000) to feature (3000)
rescaled_attention_mask = padded_inputs["attention_mask"][:, :: self.hop_length]

# The STFT computation produces L//hop_length + 1 frames,
# but we skip the last frame (see `_torch_extract_fbank_features`).
# This means we need to trim the rescaled attention mask to match
# the actual number of frames (L//hop_length) when the input length
# is not perfectly divisible by the hop length.
if padded_inputs["attention_mask"].shape[1] % self.hop_length != 0:
rescaled_attention_mask = rescaled_attention_mask[:, :-1]
padded_inputs["attention_mask"] = rescaled_attention_mask

if return_token_timestamps is not None:
padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]

if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
input_features = torch.from_numpy(np.asarray(padded_inputs["input_features"], dtype=np.float32)).to(
device="cuda", dtype=torch.bfloat16
)
attention_mask = torch.from_numpy(np.asarray(padded_inputs["attention_mask"], dtype=np.float32)).to(
device="cuda", dtype=torch.int32
)
return input_features, attention_mask
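A rough usage sketch of the extractor above, assuming a 16 kHz mono waveform and a CUDA device (the `_preprocess` method as written moves its outputs to the GPU as bfloat16 features and an int32 mask); `feature_size=128` is an assumption, not something fixed by this file:

```python
# Hypothetical usage of the WhisperFeatureExtractor defined above.
import numpy as np

extractor = WhisperFeatureExtractor(feature_size=128)        # mel-bin count is an assumption
waveform = np.random.randn(16000 * 5).astype(np.float32)     # 5 s of fake 16 kHz audio

# Returns (input_features, attention_mask); shapes are roughly
# [batch, n_mels, frames] and [batch, frames] after the hop-length rescale.
input_features, attention_mask = extractor._preprocess(
    waveform,
    padding="max_length",          # pad to 30 s (self.n_samples) before framing
    return_attention_mask=True,
    sampling_rate=16000,
)
```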
6 changes: 6 additions & 0 deletions lightllm/models/qwen3_omni_moe_thinker/infer_struct.py
@@ -0,0 +1,6 @@
from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo


class Qwen3OmniMOEInferStateInfo(Qwen3VLInferStateInfo):
def __init__(self):
super().__init__()
@@ -0,0 +1,15 @@
import torch
from lightllm.models.qwen3_vl_moe.layer_infer.transformer_layer_infer import Qwen3VLMOETransformerLayerInfer
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class Qwen3OmniMOETransformerLayerInfer(Qwen3VLMOETransformerLayerInfer):
def __init__(self, layer_num, network_config):
super().__init__(layer_num, network_config)
self.head_dim_ = network_config["head_dim"]
self.mrope_section = torch.tensor(
network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda"
)
return
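For reference, `mrope_section` follows the Qwen-VL multimodal RoPE convention: it partitions the rotary frequency groups among the temporal, height, and width position ids. A minimal illustration with hypothetical values (not this class's attention code):

```python
import torch

# Hypothetical mrope_section, as read from network_config["rope_scaling"]["mrope_section"].
mrope_section = [16, 24, 24]

# channel_of_dim[k] records which position channel (0 = t, 1 = h, 2 = w)
# drives the k-th rotary frequency group; the sections simply concatenate.
channel_of_dim = torch.cat(
    [torch.full((n,), i, dtype=torch.int32) for i, n in enumerate(mrope_section)]
)
assert channel_of_dim.numel() == sum(mrope_section)
```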
@@ -0,0 +1,30 @@
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight


class Qwen3OmniMOEThinkerPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
def __init__(self, data_type, network_config):
super().__init__(data_type, network_config)

hidden_size = network_config["hidden_size"]
vocab_size = network_config["vocab_size"]
self.wte_weight_ = EmbeddingWeight(
dim=hidden_size,
vocab_size=vocab_size,
weight_name="thinker.model.embed_tokens.weight",
data_type=self.data_type_,
)
tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
self.lm_head_weight_ = LMHeadWeight(
dim=hidden_size,
vocab_size=vocab_size,
weight_name="thinker.lm_head.weight",
data_type=self.data_type_,
embedding_weight=self.wte_weight_ if tie_word_embeddings else None,
)
self.final_norm_weight_ = RMSNormWeight(
dim=hidden_size,
weight_name="thinker.model.norm.weight",
data_type=self.data_type_,
)
return
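When `tie_word_embeddings` is set, the `LMHeadWeight` above reuses the token-embedding matrix instead of loading a separate `thinker.lm_head.weight`. A small numerical sketch of what weight tying means (illustrative only, not this class's loading code):

```python
import torch

# Hypothetical sizes: vocab_size = 8, hidden_size = 4.
embed_tokens = torch.randn(8, 4)       # shared token-embedding matrix
hidden = torch.randn(2, 4)             # hidden states for two positions

# With tied weights, output logits project onto the same matrix.
logits = hidden @ embed_tokens.T       # shape [2, 8]
```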
@@ -0,0 +1,53 @@
import os
from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight


class Qwen3OmniMOEThinkerTransformerLayerWeight(Qwen3MOETransformerLayerWeight):
def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
super().__init__(layer_num, data_type, network_config, quant_cfg)
return

def _init_weight_names(self):
self._q_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_proj.weight"
self._q_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_norm.weight"
self._q_bias_name = None
self._k_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_proj.weight"
self._k_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_norm.weight"
self._k_bias_name = None
self._v_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.v_proj.weight"
self._v_bias_name = None
self._kv_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.kv_proj.weight"
self._kv_bias_name = None
self._o_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.o_proj.weight"
self._o_bias_name = None
self._att_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.input_layernorm.weight"
self._att_norm_bias_name = None
self._ffn_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.post_attention_layernorm.weight"
self._ffn_norm_bias_name = None

def _init_moe(self):
moe_intermediate_size = self.network_config_["moe_intermediate_size"]
self.moe_gate = ROWMMWeight(
in_dim=self.network_config_["hidden_size"],
out_dims=[self.n_routed_experts],
weight_names=f"thinker.model.layers.{self.layer_num_}.mlp.gate.weight",
data_type=self.data_type_,
quant_method=None,
tp_rank=0,
tp_world_size=1,
)
self.experts = FusedMoeWeight(
gate_proj_name="gate_proj",
down_proj_name="down_proj",
up_proj_name="up_proj",
e_score_correction_bias_name="",
weight_prefix=f"thinker.model.layers.{self.layer_num_}.mlp.experts",
n_routed_experts=self.n_routed_experts,
hidden_size=self.network_config_["hidden_size"],
moe_intermediate_size=moe_intermediate_size,
data_type=self.data_type_,
quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
layer_num=self.layer_num_,
network_config=self.network_config_,
)