diff --git a/.gitignore b/.gitignore index 63408699f4..3fb49db8b1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .vscode tmp/ requirements-musa.txt +CLAUDE.md diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 859d97ca84..6429bce9a0 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -65,6 +65,7 @@ class AttControl: mla_prefill_dict: Dict = None mla_decode: bool = False mla_decode_dict: Dict = None + scale: float = None @dataclass diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..2f5fccd57b 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -220,8 +220,11 @@ def _normal_decode_att( sink_weight = None k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) + if att_control.scale is not None: + sm_scale = att_control.scale + else: + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) o = flash_attn_with_kvcache( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 5c1d2b8712..6702653c8e 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -6,7 +6,7 @@ import json import torch import torch.nn.functional as F -from typing import final, List +from typing import final, List, Optional from tqdm import tqdm from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights @@ -32,6 +32,10 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .attention import get_prefill_att_backend_class, get_decode_att_backend_class from .attention import BaseAttBackend @@ -90,6 +94,7 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) self.is_mtp_mode = self.args.mtp_mode in [ "vanilla_with_att", "eagle_with_att", @@ -103,15 +108,17 @@ def __init__(self, kvargs): self._verify_params() self._init_quant() - self._init_weights() - self._init_mem_manager() - self._init_kv_move_buffer() + with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup): + self._init_weights() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_mem_manager() + self._init_kv_move_buffer() + self._init_req_manager() self._check_mem_size() - self._init_req_manager() self._init_infer_layer() self._init_some_value() self._init_custom() - self._load_hf_weights() + self.load_weights(self.weight_dict) # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() @@ -176,17 +183,15 @@ def _init_weights(self, start_layer_index=0): ] return - def _load_hf_weights(self): + def load_weights(self, weight_dict: dict): + assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None" load_hf_weights( self.data_type, - weight_dir=self.weight_dir_, + self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, + weight_dict=weight_dict, ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 @@ -884,6 +889,7 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) + torch.cuda.empty_cache() return def autotune_layers(self): @@ -1012,6 +1018,9 @@ def _init_padded_req(self): del b_seq_len del b_ready_cache_len del model_output + del b_mtp_index + del b_prefill_start_loc + del b_q_seq_len torch.cuda.empty_cache() return @@ -1032,3 +1041,71 @@ def _gen_special_model_input(self, token_num: int): special_model_input["mtp_draft_input_hiddens"] = None return special_model_input + + def release_memory_occupation(self, tags: Optional[List[MemoryTag]]): + if tags is None: + self.release_all() + return + if MemoryTag.WEIGHT in tags: + self.release_weight() + if MemoryTag.KV_CACHE in tags: + self.release_kv_cache() + if MemoryTag.GRAPH in tags: + self.release_graph() + return + + def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]): + if tags is None: + self.resume_all() + return + if MemoryTag.WEIGHT in tags: + self.resume_weight() + if MemoryTag.KV_CACHE in tags: + self.resume_kv_cache() + if MemoryTag.GRAPH in tags: + self.resume_graph() + return + + def release_weight(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + torch.cuda.empty_cache() + gc.collect() + + def release_kv_cache(self): + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + torch.cuda.empty_cache() + gc.collect() + + def release_graph(self): + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() + gc.collect() + + def release_all(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() + gc.collect() + + def resume_weight(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + + def resume_kv_cache(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + + def resume_graph(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) + + def resume_all(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index dd29c9a833..2007e4db4e 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -7,6 +7,10 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .infer_struct import InferStateInfo @@ -24,6 +28,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192): self.max_batch_size = max_batch_size self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) # gen cuda graph batch_sizes # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] @@ -89,7 +94,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): delattr(infer_state, param_name) with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) self.graph[batch_size] = (graph_obj, infer_state, model_output) graph_obj.replay() @@ -127,7 +132,7 @@ def _capture_decode_overlap( with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) self.graph[batch_size] = ( graph_obj, diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 8cf66a5ad6..ec0e282844 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -5,6 +5,8 @@ from tqdm import tqdm import lightllm.utils.petrel_helper as utils from lightllm.utils.dist_utils import get_current_device_id +from queue import Queue +from threading import Thread def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None): @@ -65,7 +67,6 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" iterator = tqdm(iterator, total=len(candidate_files), desc=desc_str) - for _ in iterator: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 6bcf7fc03c..3dc888b6ac 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -33,6 +33,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + moe_layer_index: int = 0, ) -> None: super().__init__(data_type=data_type) self.w1_weight_name = gate_proj_name @@ -50,6 +51,7 @@ def __init__( self.enable_ep_moe = get_env_start_args().enable_ep_moe self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts + self.moe_layer_index = moe_layer_index self._init_config(network_config) self._init_redundancy_expert_params() self._init_parallel_params() @@ -130,6 +132,7 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,6 +148,8 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + moe_layer_index=self.moe_layer_index, + microbatch_index=microbatch_index, ) def low_latency_dispatch( @@ -295,6 +300,7 @@ def _create_weight(self): device_id=self.device_id_, num_experts=self.local_n_routed_experts, ) + self.w1, self.w3 = w13_param_list self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0]) self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) @@ -312,6 +318,8 @@ def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[st for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: self._load_expert(expert_idx, local_expert_idx, weights) + # for rl updated weight + self._load_merge_weight(weights) self._load_expert_scale( expert_idx, local_expert_idx, @@ -332,6 +340,7 @@ def _load_expert( w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" + row_slice_func = self.row_slicer._slice_weight col_slice_func = self.col_slicer._slice_weight if w1_weight in weights: @@ -341,6 +350,19 @@ def _load_expert( if w2_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) + def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): + w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" + w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}" + w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}" + row_slice_func = self.row_slicer._slice_weight + col_slice_func = self.col_slicer._slice_weight + if w1_merge_weight in weights: + self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1) + if w2_merge_weight in weights: + self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2) + if w3_merge_weight in weights: + self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3) + def _load_expert_scale( self, expert_idx: int, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 6ed0cef0b4..4ca1605be4 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -8,6 +8,7 @@ from lightllm.common.quantization import Quantcfg from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel import routing_manager as _routing_mgr logger = init_logger(__name__) @@ -46,6 +47,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + moe_layer_index: int = 0, ) -> None: network_config["norm_topk_prob"] = None super().__init__( @@ -62,6 +64,7 @@ def __init__( num_fused_shared_experts=num_fused_shared_experts, layer_num=layer_num, network_config=network_config, + moe_layer_index=moe_layer_index, ) self.hidden_size = network_config["hidden_size"] @@ -144,10 +147,15 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._router(router_logits, top_k) + # Rollout router replay + if _routing_mgr.g_routing_capture_manager is not None: + _routing_mgr.g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 00587ac185..1c93cb13dc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -62,5 +62,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index 8bcdb4bf90..1e81b226ec 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -3,6 +3,7 @@ from lightllm.common.quantization.no_quant import WeightPack from lightllm.common.quantization.quantize_method import QuantizationMethod from .base_impl import FuseMoeBaseImpl +from lightllm.common.basemodel import routing_manager as _routing_mgr class FuseMoeTriton(FuseMoeBaseImpl): @@ -124,6 +125,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -136,6 +139,10 @@ def __call__( num_expert_group=num_expert_group, scoring_func=scoring_func, ) + + if _routing_mgr.g_routing_capture_manager is not None and moe_layer_index is not None: + _routing_mgr.g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index) + output = self._fused_experts( input_tensor=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index ddbf98a866..15f050c14a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -47,17 +47,17 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten # 默认weight 的shape是 outxin,这也是目前最通用的约定。 -# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。 +# 这里约定row-wise沿着倒数第二维切分,col-wise沿着第一维切分。 class RowSliceMixin(SliceMixinTpl): def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[0]) - return weight[start:end, :] + weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-2]) + return weight[..., start:end, :] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: assert ( @@ -75,17 +75,17 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( - weight_scale.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[0]) - return weight_scale[start:end] + weight_scale.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-2]) + return weight_scale[..., start:end, :] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[0]) - return weight_zero_point[start:end] + weight_zero_point.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-2]) + return weight_zero_point[..., start:end, :] class ColSliceMixin(SliceMixinTpl): @@ -94,10 +94,10 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[1]) - return weight[:, start:end] + weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-1]) + return weight[..., start:end] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias / self.tp_world_size_ * self.repeat_times_ @@ -110,16 +110,16 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[1]) - return weight_scale[:, start:end] + ), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-1]) + return weight_scale[..., start:end] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[1]) - return weight_zero_point[:, start:end] + weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-1]) + return weight_zero_point[..., start:end] # awq 的量化权重是inxout存储格式,需要定制实现。 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 3630bc2c00..da9b3f4321 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -193,7 +193,9 @@ def _create_weight(self): def load_hf_weights(self, weights: Dict[str, torch.Tensor]): for weight_name in self.weight_names: if weight_name in weights: - weight = self.param_slicer._slice_weight(weights[weight_name]) + tp_start = self.tp_rank_ * self.dim0 + tp_end = (self.tp_rank_ + 1) * self.dim0 + weight = weights[weight_name][tp_start:tp_end, :, :] self.weight.copy_(weight) self.weight.load_ok = True return diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py new file mode 100644 index 0000000000..77b611130f --- /dev/null +++ b/lightllm/common/basemodel/routing_manager.py @@ -0,0 +1,191 @@ +import atexit +import torch +import numpy as np +from multiprocessing import shared_memory +from typing import Optional +from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_current_rank_in_dp +from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray +from lightllm.utils.envs_utils import get_unique_server_name + +logger = init_logger(__name__) + + +def routing_dtype_id_to_np(dtype_id: int): + if dtype_id == 1: + return np.uint8 + elif dtype_id == 2: + return np.int16 + return np.int32 + + +def get_routing_config_shm() -> SharedArray: + service_name = get_unique_server_name() + return SharedArray(f"{service_name}_routing_config", shape=(4,), dtype=np.int32) + + +class RoutingCaptureManager: + def __init__( + self, + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, + ): + self.num_moe_layers = num_moe_layers + self.topk = topk + self.num_experts = num_experts + self.kv_cache_size = kv_cache_size + + self.dtype = torch.uint8 if num_experts <= 255 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.uint8 else 2 + + # Shape: (kv_cache_size, num_moe_layers, topk) — on CPU to save GPU memory. + # Written after forward() via flush_to_routing_buffer(), read on request finish. + routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes + self.routing_buffer = torch.zeros( + (kv_cache_size, num_moe_layers, topk), + dtype=self.dtype, + device="cpu", + ) + + # Capture buffers: simple contiguous tensors written to during forward(). + capture_buf_size = max_capture_tokens * num_moe_layers * topk * dtype_bytes + self._capture_buffer = [ + torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2) + ] + + dtype_name = "uint8" if self.dtype == torch.uint8 else "int16" + logger.info( + f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " + f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, " + f"capture_buffer={capture_buf_size / 1024 / 1024:.2f}MB x2, dtype={dtype_name}" + ) + + @property + def np_dtype(self): + return np.uint8 if self.dtype == torch.uint8 else np.int16 + + @property + def dtype_id(self) -> int: + return 1 if self.dtype == torch.uint8 else 2 + + def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: + num_tokens = topk_ids.shape[0] + self._capture_buffer[microbatch_index][:num_tokens, moe_layer_index, :] = topk_ids.to(self.dtype) + + def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None: + buf = self._capture_buffer[microbatch_index][:num_tokens] # (num_tokens, num_moe_layers, topk) + self.routing_buffer[mem_indexes[:num_tokens].cpu(), :, :] = buf.cpu() + + def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray: + cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes + return self.routing_buffer[cpu_indexes, :, :].numpy() + + +g_routing_capture_manager: Optional[RoutingCaptureManager] = None + + +def create_routing_capture_manager( + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, +) -> None: + global g_routing_capture_manager + assert g_routing_capture_manager is None, "RoutingCaptureManager already exists" + g_routing_capture_manager = RoutingCaptureManager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=kv_cache_size, + max_capture_tokens=max_capture_tokens, + ) + + +def cleanup_routing_shm_pool() -> None: + """Unlink all pre-allocated routing SHM segments. Called at server shutdown.""" + try: + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + except Exception: + return + + service_name = get_unique_server_name() + + for i in range(args.running_max_req_size): + name = f"{service_name}_shm_routing_{i}" + try: + shm = shared_memory.SharedMemory(name=name) + shm.close() + shm.unlink() + except Exception: + pass + + config_name = f"{service_name}_routing_config" + try: + shm = shared_memory.SharedMemory(name=config_name) + shm.close() + shm.unlink() + except Exception: + pass + + +def init_routing_capture(model, num_moe_layers: int) -> None: + dp_rank = get_current_rank_in_dp() + logger.info(f"init_routing_capture called: num_moe_layers={num_moe_layers}, dp_rank={dp_rank}") + if dp_rank != 0: + logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}") + return + + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled." + ) + return + + num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0)) + topk = model.config.get("num_experts_per_tok", 0) + assert num_experts > 0 and topk > 0 + + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + + # Capture buffer must fit the max tokens in any single forward call. + # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size. + batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192 + max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=model.mem_manager.size + 1, + max_capture_tokens=max_capture_tokens, + ) + + mgr = g_routing_capture_manager + dtype_id = mgr.dtype_id + + max_req_total_len = args.max_req_total_len + + # Write config to cross-process SHM + shm = get_routing_config_shm() + shm.arr[0] = num_moe_layers + shm.arr[1] = topk + shm.arr[2] = dtype_id + shm.arr[3] = max_req_total_len + logger.info( + f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, " + f"dtype_id={dtype_id}, max_tokens={max_req_total_len}" + ) + atexit.register(cleanup_routing_shm_pool) diff --git a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py index 40322e5093..8e0de6a6e3 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py @@ -64,3 +64,28 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps): num_warps=4, ) return x + + +# @torch.no_grad() +# def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float): +# assert torch.is_tensor(x) and torch.is_tensor(weight) +# # assert weight.ndim == 1, weight.shape +# # assert x.is_contiguous(), "x.is_contiguous()" + +# head_dim = weight.numel() +# x2d = x.view(-1, x.shape[-1]) # (M2, N) +# M2, N = x2d.shape +# assert N % head_dim == 0, (N, head_dim) +# H = N // head_dim + +# x3 = x2d.view(M2, H, head_dim) # (M2, H, D) + +# x_fp32 = x3.to(torch.float32) +# w = weight.view(1, 1, head_dim) + +# var = x_fp32.pow(2).mean(dim=-1, keepdim=True) +# rstd = torch.rsqrt(var + eps) +# y = (x_fp32 * rstd).to(torch.bfloat16) * w + +# x3.copy_(y.to(dtype=x3.dtype)) +# return x diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py index ca8f9a1c81..6988cc4113 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py @@ -4,7 +4,7 @@ import triton.language as tl import os -rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8")) +rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "4")) @triton.jit @@ -36,12 +36,12 @@ def _rms_norm_fwd_fused( for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) + w = tl.load(W + cols, mask=mask) x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = x * rstd - y = x_hat * w + x_hat = (x * rstd).to(tl.bfloat16) + y = x_hat * w.to(tl.bfloat16) # Write output - tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + cols * y_stride1, y, mask=mask) def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None): @@ -79,22 +79,19 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) return y -def torch_rms_norm(x, weight, eps): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight +# def rmsnorm_forward(hidden_states, weight, eps, out=None): +# input_dtype = hidden_states.dtype +# hidden_states = hidden_states.to(torch.float32) +# variance = hidden_states.pow(2).mean(-1, keepdim=True) +# hidden_states = hidden_states * torch.rsqrt(variance + eps) +# out = weight * hidden_states.to(input_dtype) +# return out -def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"): - # create data - x_shape = (M, N) - w_shape = (x_shape[-1],) - weight = torch.rand(w_shape, dtype=dtype, device="cuda") - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - # forward pass - y_tri = rmsnorm_forward(x, weight, eps) - y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype) - - # compare - print("type:", y_tri.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_tri - y_ref))) - assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) - return +def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + print(f"norm weight dtype:{self.weight.dtype}") + return self.weight * hidden_states.to(input_dtype) diff --git a/lightllm/common/basemodel/triton_kernel/post_process/__init__.py b/lightllm/common/basemodel/triton_kernel/post_process/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py new file mode 100644 index 0000000000..353affd8ed --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py @@ -0,0 +1,36 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_apply_invalid_token( + Logits, + invalid_token_ids, + cu_invalid_token_num, + stride_logit_b, +): + cur_batch = tl.program_id(0) + start_index = tl.load(cu_invalid_token_num + cur_batch) + end_index = tl.load(cu_invalid_token_num + cur_batch + 1) + for i in range(start_index, end_index): + cur_invalid_token_id = tl.load(invalid_token_ids + i) + cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id + tl.store(cur_logit_ptr, float("-inf")) + return + + +def apply_invalid_token_ids( + Logits: torch.Tensor, + invalid_token_ids: torch.Tensor, + cu_invalid_token_num: torch.Tensor, +): + batch_size = Logits.shape[0] + grid = (batch_size,) + _fwd_kernel_apply_invalid_token[grid]( + Logits=Logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + stride_logit_b=Logits.stride(0), + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 7d516e6728..bcc1292097 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -4,6 +4,7 @@ from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager +from .neo_mem_manager import NeoMemoryManager __all__ = [ "MemoryManager", @@ -13,4 +14,5 @@ "PPLINT4KVMemoryManager", "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", + "NeoMemoryManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 1203cbdec7..ce628aa8f7 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -61,7 +61,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.size, dtype, head_num, - head_dim, + self.head_dim, layer_num, ) self.HOLD_TOKEN_MEMINDEX = self.size diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 1ff58b89a0..686993dde4 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -5,6 +5,7 @@ PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, + NeoMemoryManager, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -19,12 +20,16 @@ def select_mem_manager_class(): # case 1 # 先判断是否是 deepseek 系列的模型 model_class = get_llm_model_class() - from lightllm.models import Deepseek2TpPartModel + from lightllm.models import Deepseek2TpPartModel, NeoTpMOEPartModel, NeoTpPartModel if issubclass(model_class, Deepseek2TpPartModel): mem_class = Deepseek2MemoryManager logger.info(f"Model kv cache using default, mem_manager class: {mem_class}") return mem_class + # 判断是否是 neo 系列的模型 + elif issubclass(model_class, NeoTpMOEPartModel) or issubclass(model_class, NeoTpPartModel): + mem_class = NeoMemoryManager + return mem_class # case normal logger.info(f"mode setting params: {get_env_start_args().llm_kv_type}") diff --git a/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py new file mode 100755 index 0000000000..0a79aa072b --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py @@ -0,0 +1,46 @@ +import torch +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager + + +class NeoMemoryManager(MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + self.size = size + self.head_num = head_num + self.head_dim = head_dim * 2 # neo kv 是[k, k_h, k_w]拼在一起的 + self.layer_num = layer_num + self.always_copy = always_copy + self.dtype = dtype + # profile the max total token num if the size is None + self.profile_size(mem_fraction) + + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._mem_state_return = torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + self.mark_start = 0 + self.mark_end = self.size + + self.can_use_mem_size = self.size + + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name + + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) + + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self._init_buffers( + self.size, + dtype, + head_num, + self.head_dim, + layer_num, + ) + self.HOLD_TOKEN_MEMINDEX = self.size diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..c75c871c72 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..14026090e6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json new file mode 100644 index 0000000000..939c939523 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 2, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..13ba4ba8e5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ee316f610b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e027701092 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ddda23d257 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..560ca6c09d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0713de7996 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e950ff0954 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..7f479b8382 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 256, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "8448": { + "BLOCK_SIZE": 256, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..b3051c6584 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "8448": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..fdb3212216 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..a94e669353 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..441421fd5d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 095f736791..81208286b2 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -37,4 +37,6 @@ Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel +from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel +from lightllm.models.neo_chat.model import NeoTpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 98cc7c229e..97015f6b20 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -312,6 +312,7 @@ def _moe_ffn( use_grouped_topk=self.n_group, topk_group=self.topk_group, num_expert_group=self.n_group, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: @@ -339,6 +340,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None: diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f9176..bd72035072 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -242,6 +242,9 @@ def _init_moe(self): # == 0 时,说明不存在融合共享专家,共享专家单独加载和进行推理。 if self.num_fused_shared_experts == 0: self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) + first_moe = self.network_config_["first_k_dense_replace"] + freq = self.network_config_.get("moe_layer_freq", 1) + moe_layer_index = (self.layer_num_ - first_moe) // freq self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -256,6 +259,7 @@ def _init_moe(self): num_fused_shared_experts=self.num_fused_shared_experts, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_ffn(self): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..79bd327068 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -6,6 +6,7 @@ from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager @@ -49,6 +50,9 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index d80eefd16e..e5672f8210 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 7c8c30940e..7278c62fec 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -55,6 +55,7 @@ def _init_moe(self): num_fused_shared_experts=0, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) def _init_weight_names(self): diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index 9e9561eb24..cff748933d 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -2,6 +2,7 @@ from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class @@ -21,6 +22,12 @@ class GptOssTpPartModel(LlamaTpPartModel): def __init__(self, kvargs): super().__init__(kvargs) + def _init_custom(self): + super()._init_custom() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) + def _init_att_backend(self): self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])( model=self diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index c104ebccc9..cc1dc28178 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -74,14 +74,19 @@ def _init_custom(self): rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: self._init_to_get_rotary() - return - - if "rope_type" in rope_scaling: + elif "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) elif "type" in rope_scaling: scaling_type = rope_scaling["type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) else: raise ValueError(f"Unknown RoPE scaling format {rope_scaling}") + if "rope_theta_hw" in self.config: + self._init_to_get_hw_rotary() + super()._init_custom() + + def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): if scaling_type == "default" or "mrope_section" in rope_scaling: self._init_to_get_rotary() elif scaling_type == "yarn": @@ -96,7 +101,6 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - return def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) @@ -106,7 +110,6 @@ def _init_to_get_rotary(self, default_base=10000): rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) base = self.config.get("rope_theta", float(default_base)) - if "max_sequence_length" in self.config: max_seq_len = self.config["max_sequence_length"] else: @@ -139,6 +142,46 @@ def _init_to_get_rotary(self, default_base=10000): self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return + def _init_to_get_hw_rotary(self, default_base=10000): + partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2) + if self.config.get("rope_scaling", {}) is None: + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) + + base = self.config.get("rope_theta_hw", float(default_base)) + if "max_sequence_length" in self.config: + max_seq_len = self.config["max_sequence_length"] + else: + max_position_embeddings = self.config.get( + "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384 + ) + max_seq_len = max_position_embeddings * rope_scaling_factor + + # NTK + try: + ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula + except: + pass + + inv_freq = 1.0 / ( + base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) + ) + t = ( + torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) + / rope_scaling_factor + ) + freqs = torch.outer(t, inv_freq) + + self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda() + return + def _init_to_get_dynamic_ntk_rotary(self): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) max_position_embeddings = self.config.get("max_position_embeddings", 2048) diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py index 45de83e989..223a64ad51 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @@ -75,7 +75,8 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk assert Lk in {16, 32, 64, 128, 256} - sm_scale = 1.0 / (Lk ** 0.5) + Lk_scale = Lk // 2 + sm_scale = 1.0 / (Lk_scale ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py deleted file mode 100644 index b0e27ac1de..0000000000 --- a/lightllm/models/mixtral/layer_infer/_custom_ops.py +++ /dev/null @@ -1,46 +0,0 @@ -import functools -import json -import os -from typing import Any, Dict, Optional, Tuple - -import torch -import triton -import triton.language as tl -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -# Pytorch version -# Triton version in progress -def topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output, - topk=2, -): - scores = torch.softmax(gating_output, dim=-1) - topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False) - return topk_weights, topk_ids - - -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - alloc_tensor_func=torch.empty, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - M, _ = hidden_states.shape - - topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device) - topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk) - del token_expert_indicies # Not used. Will be used in the future. - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 44e66cff2d..a2968f5ab1 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -1,9 +1,6 @@ -import os import torch -import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight @@ -19,25 +16,15 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(input.view(-1, self.embed_dim_)) - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.num_experts_per_tok, + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, renormalize=self.renormalize, - alloc_tensor_func=self.alloc_tensor, - ) - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - - return fused_experts_impl( - hidden_states=hidden_states, - w1=layer_weight.experts.w1[0], - w2=layer_weight.experts.w2[0], - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=False, - w1_scale=None, - w2_scale=None, - alloc_tensor_func=self.alloc_tensor, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) + return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index 51c62fd4cb..d93cb5fb58 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -57,4 +57,5 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 3c2d7b4e87..35bf38de58 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -2,6 +2,7 @@ import numpy as np from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer @@ -45,6 +46,9 @@ def _verify_params(self): def _init_custom(self): self._init_to_get_rotary() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) return def _init_mem_manager(self): diff --git a/lightllm/models/neo_chat/__init__.py b/lightllm/models/neo_chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat/layer_infer/__init__.py b/lightllm/models/neo_chat/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..a3436b28ee --- /dev/null +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -0,0 +1,159 @@ +import torch +from functools import partial +from typing import Tuple +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd +from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight +from lightllm.distributed import all_reduce +import torch.distributed as dist +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer + + +class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + return + + def _bind_attention(self): + self._context_attention_kernel = self._context_attention_kernel + self._token_attention_kernel = self._token_decode_attention_normal + self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal + return + + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight): + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) # [T, Hq*D] + + q_hw = layer_weight.q_hw_proj.mm(input) + q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) + + k_hw = layer_weight.k_hw_proj.mm(input) + k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) + + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] + + layer_weight.q_norm_weight_(q, eps=self.eps_) + + q_h_2d = q_h.reshape(q.shape[0], -1) + q_w_2d = q_w.reshape(q.shape[0], -1) + layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) + layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) + q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + + layer_weight.k_norm_weight_( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + eps=self.eps_, + ) + + k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] + k_w_2d = k_w.reshape(q.shape[0], -1) + layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) + layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) + k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + + k = cache_kv[:, : self.tp_k_head_num_, :] + k = torch.cat([k, k_h, k_w], dim=-1) + + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + v = torch.cat([v, v_pad], dim=-1) + + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o3 = o3[:, :, : self.head_dim_].contiguous() + return o3.view(o3.shape[0], -1) + + def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + + q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + + k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + token_att_fwd( + q_3d, + k_3d, + att_m_tensor, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_kv_start_loc, + infer_state.b_seq_len, + infer_state.max_kv_seq_len, + ) + + from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( + token_attention_softmax_and_reducev, + ) + + token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd + + v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + ] + + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) + + token_softmax_reducev_fwd( + att_m_tensor, + v_3d, + o_3d, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_kv_start_loc, + infer_state.b_seq_len, + ) + return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat/layer_weights/__init__.py b/lightllm/models/neo_chat/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..e6489f39af --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." + keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..e62afae9bc --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -0,0 +1,67 @@ +from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + QKRMSNORMWeight, + ROWMMWeight, +) + + +class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" + self._q_bias_hw_name = None + self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" + self._k_bias_hw_name = None + + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + + def _init_qkv(self): + super()._init_qkv() + self.q_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.q_head_num_ * self.head_dim], + weight_names=self._q_weight_hw_name, + data_type=self.data_type_, + bias_names=self._q_bias_hw_name, + quant_method=self.get_quant_method("q_hw_proj"), + ) + self.k_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.k_head_num_ * self.head_dim], + weight_names=self._k_weight_hw_name, + data_type=self.data_type_, + bias_names=self._k_bias_hw_name, + quant_method=self.get_quant_method("k_hw_proj"), + ) + + def _init_norm(self): + super()._init_norm() + + self.q_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_h_name, + data_type=self.data_type_, + ) + self.q_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_w_name, + data_type=self.data_type_, + ) + self.k_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_h_name, + data_type=self.data_type_, + ) + self.k_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_w_name, + data_type=self.data_type_, + ) diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py new file mode 100644 index 0000000000..14d9f96dc7 --- /dev/null +++ b/lightllm/models/neo_chat/model.py @@ -0,0 +1,53 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry, llm_model_type_is +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer +from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.server.core.objs import SamplingParams +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from lightllm.models.neo_chat_moe.vision_process import smart_resize +from lightllm.models.internvl.model import InternvlTokenizer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatTransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight +from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo + + +@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3")) +class NeoTpPartModel(Qwen3TpPartModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = NeoChatTransformerLayerInfer + + pre_and_post_weight_class = NeoChatPreAndPostLayerWeight + transformer_weight_class = NeoChatTransformerLayerWeight + + infer_state_class = NeoChatInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + pass + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["llm_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/neo_chat_moe/__init__.py b/lightllm/models/neo_chat_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py new file mode 100644 index 0000000000..961ed2a61d --- /dev/null +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -0,0 +1,103 @@ +from typing import Optional, List +import torch +import numpy as np +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.req_manager import ReqManager +from lightllm.models.neo_chat_moe.triton_kernel.get_neo_position import get_neo_position_triton +from lightllm.models.llama.model import LlamaTpPartModel + + +class NeoChatInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self.position_cos = None + self.position_sin = None + self.position_cos_h = None + self.position_sin_h = None + self.position_cos_w = None + self.position_sin_w = None + + def init_some_extra_state(self, model: LlamaTpPartModel): + LlamaInferStateInfo.init_some_extra_state(self, model) + if self.is_prefill: + self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda( + non_blocking=True + ) + self.position_ids = self.get_neo_position(self.multimodal_params) + else: + b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] + for batch_idx, p in enumerate(self.multimodal_params): + position_delta = 0 + for image in p["images"]: + position_delta += image["grid_thwd"][3] + b_position_delta[batch_idx] = position_delta + position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) + self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone() + self.position_ids[1:].zero_() + + self.position_ids = self.position_ids.contiguous() + self.position_cos = model._cos_cached[self.position_ids[0]] + self.position_sin = model._sin_cached[self.position_ids[0]] + self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] + self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] + self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] + self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] + return + + def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: + if len(multimodal_params) == 0: + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids + b_image_start_idx = [] + b_image_nums = [] + b_image_start_num = [] + b_image_len = [] + image_start_num = 0 + b_image_thwd = [] + + # pad multimodal_params to batch size. + batch_size = self.b_q_seq_len.shape[0] + multimodal_params = multimodal_params + [ + {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params)) + ] + + for _, p in enumerate(multimodal_params): + images = p.get("images", []) + for img in images: + b_image_start_idx.append(img["start_idx"]) + a = img["start_idx"] + print(f"img start_idx: {a}") + b_image_len.append(img["token_num"]) + b_image_thwd.append(img["grid_thwd"]) + b_image_nums.append(len(images)) + b_image_start_num.append(image_start_num) + image_start_num += len(images) + + # 没有任何图片 + if image_start_num == 0: + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids.contiguous() + b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) + b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 + b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) + b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True) + b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True) + + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + + get_neo_position_triton( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + b_ready_cache_len=self.b_ready_cache_len, + b_q_seq_len=self.b_q_seq_len, + b_start_loc=self.b_q_start_loc, + b_image_token_tag=self.b_image_token_tag, + ) + return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/__init__.py b/lightllm/models/neo_chat_moe/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..1518d68748 --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -0,0 +1,250 @@ +import torch +from functools import partial +from typing import Tuple +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.distributed import all_reduce +import torch.distributed as dist +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.common.basemodel.attention.base_att import AttControl + + +class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, data_type, network_config): + self._is_merge_kv = network_config.get("merge_kv", True) + super().__init__(data_type, network_config) + return + + def _bind_attention(self): + self._context_attention_kernel = self._context_attention_kernel + self._token_attention_kernel = self._token_decode_attention_normal + self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal + return + + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): + if self._is_merge_kv: + return self._get_qkv_mergekv(input, infer_state, layer_weight) + else: + return self._get_qkv_not_mergekv(input, infer_state, layer_weight) + + def _get_qkv_not_mergekv( + self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight + ): + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) # [T, Hq*D] + + q_hw = layer_weight.q_hw_proj.mm(input) + q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) + + k_hw = layer_weight.k_hw_proj.mm(input) + k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) + + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] + + layer_weight.q_norm_weight_(q, eps=self.eps_) + + q_h_2d = q_h.reshape(q.shape[0], -1) + q_w_2d = q_w.reshape(q.shape[0], -1) + layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) + layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) + q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + + layer_weight.k_norm_weight_( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + eps=self.eps_, + ) + + k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] + k_w_2d = k_w.reshape(q.shape[0], -1) + layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) + layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) + k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + + k = cache_kv[:, : self.tp_k_head_num_, :] + k = torch.cat([k, k_h, k_w], dim=-1) + + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + v = torch.cat([v, v_pad], dim=-1) + + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _get_qkv_mergekv( + self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight + ): + input = input.view(-1, self.embed_dim_) + + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) + q_hw = layer_weight.q_hw_proj.mm(input) + k_hw = layer_weight.k_hw_proj.mm(input) + + layer_weight.q_norm_weight_(q, eps=self.eps_) + layer_weight.q_norm_hw_weight_(q_hw, eps=self.eps_) + layer_weight.k_norm_hw_weight_(k_hw, eps=self.eps_) + + q_hw = q_hw.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) + + layer_weight.k_norm_weight_( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + eps=self.eps_, + ) + + k_hw = k_hw.view(q.shape[0], self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + + k = cache_kv[:, : self.tp_k_head_num_, :] + k = torch.cat([k, k_h, k_w], dim=-1) + + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + v = torch.cat([v, v_pad], dim=-1) + + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_q_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_q_seq_len, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_tag, + ) + o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o3 = o3[:, :, : self.head_dim_].contiguous() + return o3.view(o3.shape[0], -1) + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: NeoChatInferStateInfo, + layer_weight: NeoChatMOETransformerLayerWeight, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + att_control = AttControl() + if att_control.scale is None: + att_control.scale = 1.0 / (self.head_dim_ ** 0.5) + # att_control.mla_decode_dict["softmax_scale"] = 1.0 / (self.head_dim_ ** 0.5) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor + ) + o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)[:, :, : self.head_dim_].contiguous() + return o_tensor + + # def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): + # total_token_num = infer_state.total_token_num + # batch_size = infer_state.batch_size + + # q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + # att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + + # k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + # token_att_fwd( + # q_3d, + # k_3d, + # att_m_tensor, + # infer_state.req_manager.req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_kv_start_loc, + # infer_state.b_seq_len, + # infer_state.max_kv_seq_len, + # ) + + # from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( + # token_attention_softmax_and_reducev, + # ) + + # token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd + + # v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + # :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + # ] + + # o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) + + # token_softmax_reducev_fwd( + # att_m_tensor, + # v_3d, + # o_3d, + # infer_state.req_manager.req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_kv_start_loc, + # infer_state.b_seq_len, + # ) + # return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat_moe/layer_weights/__init__.py b/lightllm/models/neo_chat_moe/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..4b0eae91c3 --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." + keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..26e986cdd7 --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -0,0 +1,83 @@ +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + QKRMSNORMWeight, + ROWMMWeight, +) + + +class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + self._is_merge_kv = network_config.get("merge_kv", True) + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" + self._q_bias_hw_name = None + self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" + self._k_bias_hw_name = None + + if self._is_merge_kv: + self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight" + self._k_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_hw.weight" + else: + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + + def _init_qkv(self): + super()._init_qkv() + self.q_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.q_head_num_ * self.head_dim], + weight_names=self._q_weight_hw_name, + data_type=self.data_type_, + bias_names=self._q_bias_hw_name, + quant_method=self.get_quant_method("q_hw_proj"), + ) + self.k_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.k_head_num_ * self.head_dim], + weight_names=self._k_weight_hw_name, + data_type=self.data_type_, + bias_names=self._k_bias_hw_name, + quant_method=self.get_quant_method("k_hw_proj"), + ) + + def _init_norm(self): + super()._init_norm() + if self._is_merge_kv: + self.q_norm_hw_weight_ = QKRMSNORMWeight( + dim=self.head_dim, + weight_name=self._q_norm_hw_name, + data_type=self.data_type_, + ) + self.k_norm_hw_weight_ = QKRMSNORMWeight( + dim=self.head_dim, + weight_name=self._k_norm_hw_name, + data_type=self.data_type_, + ) + else: + self.q_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_h_name, + data_type=self.data_type_, + ) + self.q_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_w_name, + data_type=self.data_type_, + ) + self.k_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_h_name, + data_type=self.data_type_, + ) + self.k_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_w_name, + data_type=self.data_type_, + ) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py new file mode 100644 index 0000000000..cf4404090f --- /dev/null +++ b/lightllm/models/neo_chat_moe/model.py @@ -0,0 +1,150 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry, llm_model_type_is +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer +from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.server.core.objs import SamplingParams +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from lightllm.models.neo_chat_moe.vision_process import smart_resize +from lightllm.models.internvl.model import InternvlTokenizer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.models.neo_chat_moe.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo + +IMG_START_TOKEN = "" +IMG_END_TOKEN = "" +IMG_TOKEN = "" +AUDIO_START_TOKEN = "" + + +class NeoChatTokenizer(BaseMultiModalTokenizer): + def __init__(self, tokenizer, model_cfg, **kwargs): + super().__init__(tokenizer) + self.tokenizer = tokenizer + self.min_pixel = model_cfg.get("vision_config").get("min_pixels") + self.max_pixel = model_cfg.get("vision_config").get("max_pixels") + self.patch_size = model_cfg.get("vision_config").get("patch_size") + self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio") + + self.image_token_id = model_cfg.get("image_token_id") + self.image_start_tag = IMG_START_TOKEN + self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag) + self.image_end_tag = IMG_END_TOKEN + self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) + + def init_imageitem_extral_params( + self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + img.extra_params["min_pixels"] = ( + sampling_params.min_pixels if sampling_params.min_pixels > 0 else self.min_pixel + ) + img.extra_params["max_pixels"] = ( + sampling_params.max_pixels if sampling_params.max_pixels > 0 else self.max_pixel + ) + assert ( + img.extra_params["min_pixels"] <= img.extra_params["max_pixels"] + ), "min_pixels should be less than or equal to max_pixels" + return + + def init_audioitem_extral_params( + self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + raise NotImplementedError + + def get_audio_token_length(self, audio: AudioItem): + raise NotImplementedError + + def get_image_token_length(self, img: ImageItem): + width, height = img.image_w, img.image_h + resized_height, resized_width = smart_resize( + height=height, + width=width, + factor=int(self.patch_size // self.downsample_ratio), + min_pixels=img.extra_params["min_pixels"], + max_pixels=img.extra_params["max_pixels"], + ) + grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size + token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2)) + # 这里的grid_h和grid_w需要* self.downsample_ratio么?再仔细看下代码 + img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num) + return token_num + + # only change the impl of the encode func: + def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): + # TEXTTEXTTEXT --> TEXTTEXTTEXT + image_tokens = IMG_START_TOKEN + IMG_END_TOKEN + if multimodal_params is None: + add_special_tokens = kwargs.get("add_special_tokens", True) + return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + image_count = len(multimodal_params.images) + if not kwargs.get("already_tokenized", False): + prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) + origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) + else: + origin_ids = prompt + # --> id,id+1...id+num + input_ids = [] + image_id = 0 + start_idx = 0 + while True: + try: + start_idx = origin_ids.index(self.image_start_id) + if start_idx + 1 >= len(origin_ids): + break + if origin_ids[start_idx + 1] == self.image_end_id: + input_ids.extend(origin_ids[: start_idx + 1]) + token_id = multimodal_params.images[image_id].token_id + token_num = multimodal_params.images[image_id].token_num + multimodal_params.images[image_id].start_idx = len(input_ids) + input_ids.extend(range(token_id, token_id + token_num)) + input_ids.append(self.image_end_id) + origin_ids = origin_ids[start_idx + 2 :] + image_id += 1 + else: + raise ValueError("image token error") + except ValueError: + break + input_ids.extend(origin_ids) + return input_ids + + +@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe")) +class NeoTpMOEPartModel(Qwen3MOEModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = NeoChatMOETransformerLayerInfer + + pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight + transformer_weight_class = NeoChatMOETransformerLayerWeight + + infer_state_class = NeoChatInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + pass + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["llm_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py new file mode 100644 index 0000000000..60fa82f2b9 --- /dev/null +++ b/lightllm/models/neo_chat_moe/neo_visual.py @@ -0,0 +1,281 @@ +import os +import torch +import torch.nn.functional as F +from PIL import Image +from typing import List +from io import BytesIO +import torch.nn as nn +from transformers.activations import ACT2FN +from safetensors import safe_open +from lightllm.server.multimodal_params import ImageItem +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from lightllm.models.neo_chat_moe.vision_process import load_image_native +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data + + +def apply_rotary_emb_1d( + x: torch.Tensor, + cos_cached: torch.Tensor, + sin_cached: torch.Tensor, + positions: torch.Tensor, +): + """对输入张量的一部分应用1D RoPE。""" + # x: (..., seq_len, dim_part) + # positions: (..., seq_len) + # cos_cached: (max_pos, dim_part / 2) + cos_cached = cos_cached.to(device=positions.device) + sin_cached = sin_cached.to(device=positions.device) + + cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) + sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) + + x1 = x[..., 0::2] + x2 = x[..., 1::2] + + rotated_x1 = x1 * cos - x2 * sin + rotated_x2 = x1 * sin + x2 * cos + + x_rotated = torch.empty_like(x) + x_rotated[..., 0::2] = rotated_x1 + x_rotated[..., 1::2] = rotated_x2 + return x_rotated + + +def apply_2d_rotary_pos_emb( + x: torch.Tensor, + cos_cached_x: torch.Tensor, + sin_cached_x: torch.Tensor, + cos_cached_y: torch.Tensor, + sin_cached_y: torch.Tensor, + abs_positions_x: torch.Tensor, + abs_positions_y: torch.Tensor, +): + """应用2D RoPE到输入张量x。""" + dim = x.shape[-1] + dim_half = dim // 2 + + # 假设我们将embedding的前半部分用于一个方向的RoPE,后半部分用于另一个方向 + # 例如,前一半给X坐标,后一半给Y坐标 (或者反过来,但要保持一致) + x_part_1 = x[..., :dim_half] + x_part_2 = x[..., dim_half:] + + # 将与 abs_positions_x 相关的旋转应用于 x_part_1 + rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x) + # 将与 abs_positions_y 相关的旋转应用于 x_part_2 + rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y) + + # 将它们重新拼接起来。确保顺序与你分割时一致。 + return torch.cat((rotated_part_1, rotated_part_2), dim=-1) + + +def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): + """ + Compute patch coordinates (x, y) + + Args: + grid_hw: (B, 2) tensor representing (H, W) per image + """ + device = grid_hw.device + B = grid_hw.shape[0] + + # Get the number of patches per image + H = grid_hw[:, 0] + W = grid_hw[:, 1] + N = H * W + N_total = N.sum() + + # Create the batch index for each patch (B x patch count) + patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) + + # Generate intra-image patch index (row-major order) + patch_id_within_image = torch.arange(N_total, device=device) + patch_id_within_image = ( + patch_id_within_image + - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] + ) + + # Get H/W for each patch according to its image + W_per_patch = W[patch_to_sample] + abs_x = patch_id_within_image % W_per_patch + abs_y = patch_id_within_image // W_per_patch + + return abs_x, abs_y + + +class NeoVisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + hidden_size: int = 1024, + llm_hidden_size: int = 2048, + downsample_ratio: float = 0.5, + patch_size: int = 16, + num_channels: int = 3, + max_position_embeddings_vision: int = 10000, + rope_theta_vision: float = 10000.0, + min_pixels: int = 65536, + max_pixels: int = 2408448, + **kwargs, + ): + super().__init__() + self.weight_dir = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + self.embed_dim = hidden_size + self.llm_hidden_size = llm_hidden_size + self.patch_size = patch_size + self.num_channels = num_channels + self.downsample_ratio = downsample_ratio + self.downsample_factor = int(1 / downsample_ratio) + self.max_position_embeddings_vision = max_position_embeddings_vision + self.rope_theta_vision = rope_theta_vision + self.rope_dim_part = self.embed_dim // 2 + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + self.patch_embedding = nn.Conv2d( + in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size + ) + + self.dense_embedding = nn.Conv2d( + in_channels=self.embed_dim, + out_channels=self.llm_hidden_size, + kernel_size=self.downsample_factor, + stride=self.downsample_factor, + ) + self.gelu = nn.GELU() + + self.repe_dim_part = self.embed_dim // 2 + self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos() + self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos() + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def load_model(self, weight_dir): + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) + self.load_state_dict(weight_dict) + + def precompute_rope_freqs_sincos(self): + inv_freq = 1.0 / ( + self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part) + ) + t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq) + freqs = torch.outer(t, inv_freq) + return torch.cos(freqs), torch.sin(freqs) + + def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): + """ + Apply 2D Rotary Position Embedding to the patch embeddings. + """ + abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) + embeddings = apply_2d_rotary_pos_emb( + patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 + self.cos_x, + self.sin_x, + self.cos_y, + self.sin_y, + abs_pos_x, + abs_pos_y, + ).to(self.patch_embedding.weight.dtype) + return embeddings + + def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + pixel_values = pixel_values.view( + -1, + 3, + self.patch_size, + self.patch_size, + ) + patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) + patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) + assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[ + 0 + ], "Grid size and patch embeds size mismatch." + + patches_list = [] + cur_position = 0 + for i in range(grid_hw.shape[0]): + h, w = grid_hw[i] + patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) + patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) + patches_per_img = patches_per_img.permute(0, 2, 3, 1) + patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) + cur_position += h * w + + embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) + assert cur_position == patch_embeds.shape[0] + assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) + + return embeddings + + def encode(self, images: List[ImageItem]): + img_tensors = [] + valid_ids = [] + valid_id = 0 + img_grids = [] + uuids = [] + + for i, img in enumerate(images): + if isinstance(img, ImageItem): + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + # a = img.extra_params["min_pixels"] + # b = img.extra_params["max_pixels"] + # print(f"self.min_pixels is {a} ,max_pixelx is {b}") + pixel_values, image_grid_hw = load_image_native( + image_data, + patch_size=self.patch_size, + downsample_ratio=self.downsample_ratio, + min_pixels=img.extra_params["min_pixels"], + max_pixels=img.extra_params["max_pixels"], + ) + img_tensors.append(pixel_values) + img_grids.append(image_grid_hw) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + + # must devide merge_length + cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + grid_hw = torch.cat(img_grids, dim=0) + + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_hw = grid_hw.to("cuda", non_blocking=True) + + all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw) + + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/neo_chat_moe/triton_kernel/__init__.py b/lightllm/models/neo_chat_moe/triton_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py new file mode 100644 index 0000000000..42c3254e27 --- /dev/null +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -0,0 +1,431 @@ +import math +import torch +import triton +import triton.language as tl + +from lightllm.utils.device_utils import is_tesla + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + Out, + position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] + B_Start_Loc, + B_Seqlen, + Req_to_tokens, + B_req_idx, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + kv_group_num, + b_prompt_cache_len, + b_image_token_tag, + H: tl.constexpr, + QK_HEAD_DIM: tl.constexpr, + V_HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + cur_bh = tl.program_id(1) + cur_batch = cur_bh // H + cur_head = cur_bh % H + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + total_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = total_len - prompt_cache_len # NEW len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + block_start_loc = BLOCK_M * start_m + if block_start_loc >= cur_batch_seq_len: + return + + offs_n = tl.arange(0, BLOCK_N) + offs_d_qk = tl.arange(0, QK_HEAD_DIM) + offs_d_v = tl.arange(0, V_HEAD_DIM) + offs_m = block_start_loc + tl.arange(0, BLOCK_M) + + # Q pointers + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d_qk[None, :] * stride_qd + ) + + q_valid = offs_m < cur_batch_seq_len + q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) + + # online softmax state + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) + block_end_loc = total_len + + # absolute q positions in the request + q_pos = prompt_cache_len + offs_m # [M] + q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False) + + for start_n in range(0, block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + k_pos = start_n + offs_n # [N] + k_valid = k_pos < block_end_loc + + # map logical pos -> mem_index (for K/V) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_valid, + other=0, + ).to(tl.int64) + + # load K + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d_qk[:, None] * stride_kd + k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + # mask: causal OR same gid (only possible inside NEW part) + mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None] + qk = tl.where(mask, qk * sm_scale, -1.0e8) + + # online softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # load V + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d_v[None, :] * stride_vd + v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) + + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d_v[None, :] * stride_od + ) + tl.store(Out + off_o, acc, mask=q_valid[:, None]) + + +@torch.no_grad() +def context_attention_fwd_neo( + q, + k, + v, + o, + position_ids, # 1D packed like q (only NEW tokens) + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, + b_image_token_tag, +): + # minimal safety: position_ids must cover packed q rows + assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) + + BLOCK_M = 128 if not is_tesla() else 64 + + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128, 256} + base_head_dim = Lq // 2 + sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634 + + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + + BLOCK_N = BLOCK_M + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + position_ids, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + b_image_token_tag=b_image_token_tag, + H=head, + QK_HEAD_DIM=Lk, + V_HEAD_DIM=Lk // 2, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def reference_attention( + q, + k, + v, + position_ids_q, # 1D packed like q (only NEW tokens) + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, +): + device = q.device + dtype = q.dtype + sum_q, Hq, D = q.shape + Hk = k.shape[1] + kv_group_num = Hq // Hk + + batch = b_seq_len.shape[0] + out = torch.empty_like(q) + scale = 1.0 / math.sqrt(D) + + for b in range(batch): + req = int(b_req_idx[b].item()) + total_len = int(b_seq_len[b].item()) + prompt_len = int(b_prompt_cache_len[b].item()) + new_len = total_len - prompt_len + + q_start = int(b_start_loc[b].item()) + q_blk = q[q_start : q_start + new_len] # [M, Hq, D] + gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] + + # gather K/V for full request by logical pos -> mem_index + token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] + k_blk = k[token_locs] # [L, Hk, D] + v_blk = v[token_locs] # [L, Hk, D] + + # expand kv heads to q heads (GQA) + k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + + # positions + q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] + k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] + + # build allow mask: + # causal always + allow = k_pos[None, :] <= q_pos[:, None] + + # full-attn only inside NEW part by gid + # compare only when k_pos in NEW + k_in_new = k_pos >= prompt_len + k_rel = (k_pos - prompt_len).clamp_min(0) # [L] + # map k_rel to gid_new, but only valid where k_in_new + k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) + k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new + k_gid[k_in_new] = gid_new[k_rel[k_in_new]] + + allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) + + # scores: [Hq, M, L] + q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] + k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] + scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] + + neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) + scores = torch.where(allow[None, :, :], scores, neg) + + p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] + v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] + out_hq = torch.matmul(p, v_t) # [Hq, M, D] + out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] + + out[q_start : q_start + new_len] = out_blk + + return out + + +def make_test_case( + device="cuda", + dtype=torch.float16, + batch=3, + Hq=8, + Hk=4, + D=64, + seed=0, + base_index=50000, +): + torch.manual_seed(seed) + + # prompt (cached) len and new len + prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) + new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) + total_lens = (prompt_lens + new_lens).to(torch.int32) + + max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) + + # packed q start + b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) + cur = 0 + for b in range(batch): + b_start_loc[b] = cur + cur += int(new_lens[b].item()) + sum_q = cur + + b_seq_len = total_lens + b_prompt_cache_len = prompt_lens.to(torch.int32) + + # one req per batch + num_req = batch + b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) + + # global KV space large, indices not small + sum_kv = int(total_lens.sum().item()) + kv_size = base_index + sum_kv + 1024 + pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index + + # Req_to_tokens [num_req, max_total_len] + req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) + p = 0 + for r in range(num_req): + L = int(total_lens[r].item()) + req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) + p += L + + # position_ids_q: only NEW tokens, packed like q + position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) + for b in range(batch): + M = int(new_lens[b].item()) + start = int(b_start_loc[b].item()) + + gid = torch.arange(M, device=device, dtype=torch.int32) + + # make one repeated block inside NEW part to simulate image tokens + if M >= 4 and torch.rand((), device=device).item() > 0.3: + s = int(torch.randint(0, M - 2, (1,), device=device).item()) + e = min(M, s + 3) + gid[s:e] = gid[s] + + position_ids_q[start : start + M] = gid + + q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) + k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) + + return ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + +def check_once(device="cuda", dtype=torch.float16, seed=0): + ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) = make_test_case(device=device, dtype=dtype, seed=seed) + + context_attention_fwd_neo( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + ref = reference_attention( + q, + k, + v, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + ) + + diff = (o - ref).abs() + max_abs = diff.max().item() + denom = ref.abs().max().item() + 1e-6 + max_rel = max_abs / denom + + print(f"seed={seed}, dtype={dtype}") + print(f"max_abs_error = {max_abs:.6e}") + print(f"max_rel_error = {max_rel:.6e}") + print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + torch.cuda.synchronize() + check_once(dtype=torch.bfloat16, seed=0) + check_once(dtype=torch.bfloat16, seed=1) + check_once(dtype=torch.bfloat16, seed=2) diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py new file mode 100644 index 0000000000..1a3d4af73b --- /dev/null +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -0,0 +1,191 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_thwd_stride0: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: torch.Tensor, + position_ids: torch.Tensor, + position_ids_stride0: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, + b_image_token_tag: torch.Tensor, + BLOCK_SIZE: tl.constexpr, +) -> torch.Tensor: + cur_batch = tl.program_id(0) + cache_len = tl.load(b_ready_cache_len + cur_batch) + q_seq_len = tl.load(b_q_seq_len + cur_batch) + image_num = tl.load(b_image_nums + cur_batch) + image_start_num = tl.load(b_image_start_num + cur_batch) + start_loc = tl.load(b_start_loc + cur_batch) + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_start_idx = start_loc + local_image_start_idx - cache_len + image_len = tl.load(b_image_len + image_start_num + i) + # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) + image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) + for j in range(0, image_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + # 目前没考虑视频,所以t 恒为 0 + t_pos = local_image_start_idx + off * 0 + h_pos = off // image_w + w_pos = off % image_w + tl.store( + b_image_token_tag + off + image_start_idx, + True, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + off + image_start_idx, + t_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 + off + image_start_idx, + h_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 * 2 + off + image_start_idx, + w_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_len = tl.load(b_image_len + image_start_num + i) + image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3) + image_end = local_image_start_idx + image_len - cache_len + text_start = tl.maximum(0, image_end) + for j in range(text_start, q_seq_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta + h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0) + w_pos = tl.load( + position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0 + ) + tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len)) + return + + +def get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: torch.Tensor, + position_ids: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, + b_image_token_tag: torch.Tensor, +) -> torch.Tensor: + + batch_size = b_q_seq_len.shape[0] + assert batch_size == b_image_nums.shape[0] + grid = (batch_size,) + BLOCK_SIZE = 64 + _get_neo_position_triton[grid]( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_thwd_stride0=b_image_thwd.stride(0), + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + position_ids_stride0=position_ids.stride(0), + b_ready_cache_len=b_ready_cache_len, + b_q_seq_len=b_q_seq_len, + b_start_loc=b_start_loc, + b_image_token_tag=b_image_token_tag, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +def test(): + b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") + b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") + b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") + position_ids = ( + torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + .unsqueeze(0) + .expand(3, -1) + .contiguous() + ) + b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda") + position_ids[1:].zero_() + b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") + b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") + b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") + get_neo_position_triton( + b_image_start_idx, + b_image_thwd, + b_image_nums, + b_image_start_num, + b_image_len, + position_ids, + b_ready_cache_len, + b_q_seq_len, + b_start_loc, + b_image_token_tag, + ) + + print(b_image_token_tag) + print(position_ids) + # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) + + # position_ids = ( + # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + # .unsqueeze(0) + # .expand(3, -1) + # .contiguous() + # ) + # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") + # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") + # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") + + # get_neo_position_triton( + # b_image_start_idx, + # b_image_thwd, + # b_image_nums, + # b_image_start_num, + # b_image_len, + # position_ids, + # b_ready_cache_len, + # b_q_seq_len, + # b_start_loc, + # ) + + # print(f"old_value:\n{old_value}") + # print(f"position_ids:\n{position_ids}") + # assert torch.equal(old_value, position_ids) + + """ + tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], + [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8], + [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]], + device='cuda:0', dtype=torch.int32) + """ + + +if __name__ == "__main__": + test() diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py new file mode 100644 index 0000000000..fbd57a5e9c --- /dev/null +++ b/lightllm/models/neo_chat_moe/vision_process.py @@ -0,0 +1,141 @@ +import re +import math +import torch +import string +import numpy as np +import pandas as pd +from PIL import Image +import torch.distributed as dist +import torchvision.transforms as T + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 +def smart_resize( + height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, floor_by_factor(height / beta, factor)) + w_bar = max(factor, floor_by_factor(width / beta, factor)) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def preprocess_pixel_values(pixel_values, patch_size=16): + c, h, w = pixel_values.shape + grid_h = h // patch_size + grid_w = w // patch_size + + flatten_pixel_values = ( + pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) + .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] + .reshape(grid_h * grid_w, c * patch_size ** 2) + ) + + grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) + + return flatten_pixel_values, grid_hw + + +def get_contrasting_background(image): + """ + Calculate the color (white or black) that is different from the average foreground color + to use as the background color + """ + image_np = np.array(image) + if (image_np[:, :, 3] == 0).any(): + non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] + if non_transparent_pixels.size == 0: + return None + pixel_mean = non_transparent_pixels.mean() + contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) + return contrasting_color + else: + return None + + +def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): + """ + Load and preprocess an image file, converting it to RGB mode, + resizing, normalizing, and optionally adding a thumbnail version. + """ + if image.mode == "RGBA": + bg_color = get_contrasting_background(image) + if bg_color: + background = Image.new("RGB", image.size, bg_color) + background.paste(image, mask=image.split()[3]) + image = background.convert("RGB") + else: + image = image.convert("RGB") + else: + image = image.convert("RGB") + + if upscale: + image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) + + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), + T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ] + ) + + new_image = dynamic_preprocess_native_resolution( + image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels + ) + pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) + + # print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") + + return pixel_values, grid_hw diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 9eccddffc1..af035e81b6 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -93,8 +93,10 @@ def _tpsp_get_qkv( input = gather_input[0 : len(infer_state.input_ids), :] input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input) + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) layer_weight.q_norm_weight_(q, eps=self.eps_) layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -130,6 +132,7 @@ def _moe_ffn( use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) @@ -150,6 +153,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) ep_output = ep_output.view(token_num, hidden_dim) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 13ba6cbe0f..5a857fd093 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -52,6 +52,11 @@ def _init_moe(self): tp_rank=0, tp_world_size=1, ) + mlp_only = set(self.network_config_.get("mlp_only_layers", [])) + step = self.network_config_.get("decoder_sparse_step", 1) + moe_layer_index = sum( + 1 for i in range(self.layer_num_) if self.n_routed_experts > 0 and i not in mlp_only and (i + 1) % step == 0 + ) self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -65,6 +70,7 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_qkv(self): diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index 10a5051276..2926a12b1f 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -4,6 +4,7 @@ from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -26,3 +27,6 @@ def __init__(self, kvargs): def _init_custom(self): super()._init_custom() dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 73b9bad4a4..409460feb0 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -1,8 +1,7 @@ import argparse -def make_argument_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() +def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument( "--run_mode", @@ -607,6 +606,12 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used.""" ) + parser.add_argument( + "--enable_torch_memory_saver", + action="store_true", + help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""", + ) + parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""") parser.add_argument( "--disk_cache_dir", type=str, @@ -639,4 +644,10 @@ def make_argument_parser() -> argparse.ArgumentParser: If the op is not implemented for the platform and the hardware support triton, it will use triton implementation.""", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) return parser diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index be2315d34c..2502df9777 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -33,7 +33,7 @@ import uuid from PIL import Image import multiprocessing as mp -from typing import AsyncGenerator, Union +from typing import Any, AsyncGenerator, Union from typing import Callable from lightllm.server import TokenLoad from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect @@ -49,6 +49,7 @@ from lightllm.utils.error_utils import ServerBusyError from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import ReleaseMemoryReq, ResumeMemoryReq from dataclasses import dataclass from .api_openai import chat_completions_impl, completions_impl @@ -58,6 +59,15 @@ CompletionRequest, CompletionResponse, ) +from .io_struct import ( + AbortReq, + FlushCacheReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, + GeneralModelToHttpRpcRsp, +) from .build_prompt import build_prompt, init_tokenizer logger = init_logger(__name__) @@ -132,6 +142,22 @@ def get_model_name(): return {"model_name": g_objs.args.model_name} +@app.get("/get_server_info") +@app.post("/get_server_info") +def get_server_info(): + # 将 StartArgs 转换为字典格式 + from dataclasses import asdict + + server_info: dict[str, Any] = asdict(g_objs.args) + return {**server_info} + + +@app.get("/get_weight_version") +@app.post("/get_weight_version") +def get_weight_version(): + return {"weight_version": g_objs.args.weight_version} + + @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") @@ -293,6 +319,83 @@ async def metrics() -> Response: return response +@app.post("/abort_request") +async def abort_request(request: AbortReq, raw_request: Request): + """Abort a request.""" + try: + await g_objs.httpserver_manager.abort_request(request) + return Response(status_code=200) + except Exception as e: + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + + +async def handle_request_common(request_obj, handler): + try: + ret: GeneralModelToHttpRpcRsp = await handler(request_obj) + if ret.success: + return JSONResponse({"success": ret.success, "message": ret.msg}, status_code=200) + else: + return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) + except Exception as e: + logger.error("handle_request_common (%s) error occurred: %s", str(request_obj), str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + + +@app.post("/init_weights_update_group") +async def init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request): + """Init weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group) + + +@app.post("/destroy_weights_update_group") +async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request): + """Destroy weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group) + + +@app.post("/update_weights_from_distributed") +async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed) + + +@app.post("/update_weights_from_tensor") +async def update_weights_from_tensor(request: UpdateWeightsFromTensorReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor) + + +@app.post("/flush_cache") +@app.get("/flush_cache") +async def flush_cache(): + """Flush the radix cache.""" + return await handle_request_common(FlushCacheReq(), g_objs.httpserver_manager.flush_cache) + + +@app.post("/pause_generation") +async def pause_generation(): + await g_objs.httpserver_manager.pause_generation() + return Response(content="Generation paused successfully.", status_code=200) + + +@app.post("/continue_generation") +async def continue_generation(): + await g_objs.httpserver_manager.continue_generation() + return Response(content="Generation continued successfully.", status_code=200) + + +@app.get("/release_memory_occupation") +@app.post("/release_memory_occupation") +async def release_memory_occupation(request: ReleaseMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.release_memory_occupation) + + +@app.get("/resume_memory_occupation") +@app.post("/resume_memory_occupation") +async def resume_memory_occupation(request: ResumeMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.resume_memory_occupation) + + @app.websocket("/pd_register") async def register_and_keep_alive(websocket: WebSocket): await websocket.accept() diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f54..d15bec6485 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] return_details = sample_params_dict.pop("return_details", False) + return_routed_experts = sample_params_dict.pop( + "return_routed_experts", httpserver_manager.args.enable_return_routed_experts + ) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() @@ -53,6 +56,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt_token_ids = None is_first_metadata = True input_usage = None + routed_experts_data = None async for sub_req_id, request_output, metadata, finish_status in results_generator: # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids @@ -78,6 +82,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + if "routed_experts" in metadata: + routed_experts_data = metadata["routed_experts"] n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] @@ -102,6 +108,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage + if return_routed_experts and routed_experts_data is not None: + ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) @@ -112,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] _ = sample_params_dict.pop("return_details", False) + _ = sample_params_dict.pop("return_routed_experts", None) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808a..1eb5ff24c0 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -1,15 +1,36 @@ import torch -from .api_cli import make_argument_parser +from .api_cli import add_cli_args +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.utils.log_utils import init_logger -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess - parser = make_argument_parser() - args = parser.parse_args() +logger = init_logger(__name__) + + +def launch_server(args: StartArgs): from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + try: + # this code will not be ok for settings to fork to subprocess + torch.multiprocessing.set_start_method("spawn") + except RuntimeError as e: + logger.warning(f"Failed to set start method: {e}") + except Exception as e: + logger.error(f"Failed to set start method: {e}") + raise e + if args.run_mode == "pd_master": pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) else: normal_or_p_d_start(args) + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + add_cli_args(parser) + args = parser.parse_args() + + launch_server(StartArgs(**vars(args))) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index bd8e4db8b2..58dac941b0 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import os import sys import time @@ -16,6 +17,7 @@ from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip from lightllm.utils.shm_size_check import check_recommended_shm_size +from lightllm.server.core.objs.start_args_type import StartArgs logger = init_logger(__name__) @@ -51,20 +53,44 @@ def signal_handler(sig, frame): process_manager.terminate_all_processes() logger.info("All processes have been terminated gracefully.") sys.exit(0) + elif sig == signal.SIGHUP: + logger.info("Received SIGHUP (terminal closed), shutting down gracefully...") + if http_server_process and http_server_process.poll() is None: + http_server_process.send_signal(signal.SIGTERM) + + start_time = time.time() + while (time.time() - start_time) < 60: + if not is_process_active(http_server_process.pid): + logger.info("httpserver exit") + break + time.sleep(1) + + if time.time() - start_time < 60: + logger.info("HTTP server has exited gracefully") + else: + logger.warning("HTTP server did not exit in time, killing it...") + kill_recursive(http_server_process) + + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully due to terminal closure.") + sys.exit(0) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGHUP, signal_handler) logger.info(f"start process pid {os.getpid()}") logger.info(f"http server pid {http_server_process.pid}") return -def normal_or_p_d_start(args): - from lightllm.server.core.objs.start_args_type import StartArgs +def _set_envs_and_config(args: StartArgs): + mp.set_start_method("spawn", force=True) + - args: StartArgs = args +def _launch_subprocesses(args: StartArgs): + _set_envs_and_config(args) set_unique_server_name(args) if not args.disable_shm_warning: @@ -138,6 +164,10 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # automatically set visual_dp based on visual_tp and tp + if args.visual_tp < args.tp and args.tp % args.visual_tp == 0: + args.visual_dp = args.tp // args.visual_tp + # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -238,6 +268,7 @@ def normal_or_p_d_start(args): visual_nccl_ports.append(can_use_ports[0]) can_use_ports = can_use_ports[1:] + args.visual_nccl_ports = visual_nccl_ports # 将申请好的端口放入args参数中 if args.nccl_port is None: args.nccl_port = nccl_port @@ -333,6 +364,13 @@ def normal_or_p_d_start(args): ], ) + return process_manager + + +def normal_or_p_d_start(args: StartArgs): + + process_manager = _launch_subprocesses(args) + # 启动 gunicorn command = [ "gunicorn", @@ -372,7 +410,7 @@ def normal_or_p_d_start(args): return -def pd_master_start(args): +def pd_master_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "pd_master": return @@ -435,7 +473,7 @@ def pd_master_start(args): http_server_process.wait() -def config_server_start(args): +def config_server_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "config_server": return diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py index 945e67681d..fdea28f576 100644 --- a/lightllm/server/audioserver/manager.py +++ b/lightllm/server/audioserver/manager.py @@ -11,7 +11,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.utils.log_utils import init_logger -from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes +from lightllm.server.io_struct import BaseReq, GenerateReqIndex from lightllm.server.core.objs.shm_req_manager import ShmReqManager, StartArgs from lightllm.server.multimodal_params import AudioItem from .model_infer.model_rpc import start_model_process, AudioModelRpcClient @@ -42,7 +42,7 @@ def __init__( self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cache_port = args.cache_port - self.waiting_reqs: List[GroupReqIndexes] = [] + self.waiting_reqs: List[GenerateReqIndex] = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp self.world_size = 1 @@ -136,8 +136,8 @@ async def loop_for_fwd(self): async def loop_for_netio_req(self): while True: - recv_req: GroupReqIndexes = await self.zmq_recv_socket.recv_pyobj() - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = await self.zmq_recv_socket.recv_pyobj() + if isinstance(recv_req, GenerateReqIndex): self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" diff --git a/lightllm/server/core/objs/io_objs/__init__.py b/lightllm/server/core/objs/io_objs/__init__.py index c9b806c47d..10386b70e6 100644 --- a/lightllm/server/core/objs/io_objs/__init__.py +++ b/lightllm/server/core/objs/io_objs/__init__.py @@ -1 +1 @@ -from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd +from .group_req import AbortedReqCmd, StopStrMatchedReqCmd diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index dfcbdd2562..d644c0c316 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -1,33 +1,10 @@ from dataclasses import dataclass from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.core.objs.sampling_params import SamplingParams from typing import List from ..req import Req -@dataclass -class GroupReqIndexes: - group_req_id: int - multimodal_params: MultimodalParams - shm_req_indexes: List[int] - time_mark: float - - -@dataclass -class GroupReqObjs: - group_req_id: int - multimodal_params: MultimodalParams - shm_req_objs: List[Req] - time_mark: float - - def to_group_req_index(self): - return GroupReqIndexes( - group_req_id=self.group_req_id, - multimodal_params=self.multimodal_params, - shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], - time_mark=self.time_mark, - ) - - @dataclass class AbortedReqCmd: req_id: int diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 887f360c84..08921317e8 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -54,6 +54,8 @@ def __init__( # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. allowed_token_ids: Optional[List[int]] = None, + # if provided, the invalid token ids will be ignored during generation + invalid_token_ids: Optional[List[int]] = None, # p d mode used params group_request_id: Optional[int] = None, # move kv to deocde node, only used in pd mode @@ -88,6 +90,7 @@ def __init__( self.guided_grammar = guided_grammar self.guided_json = guided_json self.allowed_token_ids = allowed_token_ids + self.invalid_token_ids = invalid_token_ids self.group_request_id = group_request_id self.move_kv_to_decode_node = move_kv_to_decode_node self.suggested_dp_index = suggested_dp_index @@ -109,13 +112,18 @@ def __init__( def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) cls._stop_sequences = generation_cfg.get("stop", None) except: pass @@ -267,6 +275,7 @@ def to_dict(self): ret["guided_grammar"] = self.guided_grammar ret["guided_json"] = self.guided_json ret["allowed_token_ids"] = self.allowed_token_ids + ret["invalid_token_ids"] = self.invalid_token_ids ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node return ret diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index f489aac9c2..5c7e56843b 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -1,6 +1,7 @@ import os import math import ctypes +import base64 import numpy as np import time from .sampling_params import SamplingParams @@ -13,6 +14,7 @@ from lightllm.utils.kv_cache_utils import compute_token_list_hash from typing import List, Any, Union from lightllm.utils.log_utils import init_logger +from lightllm.utils.shm_utils import create_or_link_shm logger = init_logger(__name__) @@ -24,19 +26,20 @@ class FinishStatus(ctypes.Structure): NO_FINISH = 0 FINISHED_STOP = 1 FINISHED_LENGTH = 2 + FINISHED_ABORTED = 3 def __init__(self, init_state=NO_FINISH): self.status = init_state def set_status(self, new_status): - assert 0 <= new_status <= 2 + assert 0 <= new_status <= 3 self.status = new_status def get_status(self): return self.status def is_finished(self): - return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH + return self.FINISHED_STOP <= self.status <= self.FINISHED_ABORTED def is_stopped(self): return self.status == self.FINISHED_STOP @@ -49,6 +52,8 @@ def get_finish_reason(self): return "stop" elif self.status == self.FINISHED_LENGTH: return "length" + elif self.status == self.FINISHED_ABORTED: + return "abort" return None @@ -122,6 +127,8 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), + # Number of tokens in routing data SHM, written by model worker, read by HTTP server. + ("shm_routing_num_tokens", ctypes.c_int), ] def get_str(self): @@ -179,6 +186,7 @@ def init( self._mtp_step = get_env_start_args().mtp_step self.stop_str_matched = False self.stop_str_matched_token_index = -1 + self.shm_routing_num_tokens = 0 self.post_init() @@ -227,6 +235,69 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return + def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.int8): + """Create routing SHM at actual size (on-demand, not pre-allocated). + + Uses smart mode: links if same-sized SHM exists, otherwise creates new. + """ + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_tokens, num_moe_layers, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.create_shm() + self.shm_routing_num_tokens = num_tokens + return + + def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=np.int8): + """Link to routing SHM from the reader side (HTTP server).""" + if num_moe_layers == 0: + return + num_tokens = self.shm_routing_num_tokens + if num_tokens <= 0: + return + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_tokens, num_moe_layers, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.link_shm() + return + + def get_routing_data(self): + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + return None + return self.shm_routing_data.arr + + def close_routing_data_shm_array(self): + """Close and unlink routing SHM (on-demand, no longer pooled).""" + if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: + self.shm_routing_data.close_shm() + self.shm_routing_data = None + self.shm_routing_num_tokens = 0 + return + + def get_routing_metadata(self, num_moe_layers: int, topk: int, dtype_id: int = 1): + if num_moe_layers == 0 or topk == 0: + return None + if self.shm_routing_num_tokens <= 0: + return None + try: + from lightllm.common.basemodel.routing_manager import routing_dtype_id_to_np + + np_dtype = routing_dtype_id_to_np(dtype_id) + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + self.link_routing_data_shm_array(num_moe_layers, topk, np_dtype=np_dtype) + routing_data = self.get_routing_data() + if routing_data is None: + return None + return { + "shape": list(routing_data.shape), + "dtype": str(routing_data.dtype), + "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), + } + except Exception as e: + logger.warning(f"Failed to read routing data for req {self.request_id}: {e}") + return None + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() @@ -247,9 +318,8 @@ def can_release(self): ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - if self.is_aborted and can_released_mark and ref_count_ok: - return True - + # if self.is_aborted and can_released_mark and ref_count_ok: + # return True ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty(): diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index d955aa6a87..31e2fbefed 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -17,6 +17,7 @@ REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048)) GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048)) JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048)) +INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10)) class StopSequence(ctypes.Structure): @@ -205,6 +206,25 @@ def to_list(self): return list(self.ids[: self.size]) +class InvalidTokenIds(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH), + ("size", ctypes.c_int), + ] + + def initialize(self, ids: List[int]): + self.size = len(ids) + assert ( + self.size <= INVALID_TOKEN_IDS_MAX_LENGTH + ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}." + self.ids[: self.size] = ids[:] + return + + def to_list(self): + return list(self.ids[: self.size]) + + class ExponentialDecayLengthPenalty(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -293,6 +313,8 @@ class SamplingParams(ctypes.Structure): ("ignore_eos", ctypes.c_bool), # the max number of image patches to be used in the internvl model, for the test ("image_max_patch_num", ctypes.c_int), + ("min_pixels", ctypes.c_int), + ("max_pixels", ctypes.c_int), ("max_new_tokens", ctypes.c_int), ("min_new_tokens", ctypes.c_int), # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty @@ -304,6 +326,8 @@ class SamplingParams(ctypes.Structure): # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. ("allowed_token_ids", AllowedTokenIds), + # if provided, the invalid token ids will be ignored during generation + ("invalid_token_ids", InvalidTokenIds), ("stop_sequences", StopSequenceGroups), ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty), ("group_request_id", ctypes.c_int64), # p d mode used params @@ -333,22 +357,28 @@ class SamplingParams(ctypes.Structure): def init(self, tokenizer, **kwargs): super().__init__() - self.best_of = kwargs.get("best_of", 1) - self.n = kwargs.get("n", self.best_of) - self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) - self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) - self.temperature = kwargs.get("temperature", SamplingParams._temperature) - self.top_p = kwargs.get("top_p", SamplingParams._top_p) - self.top_k = kwargs.get("top_k", SamplingParams._top_k) - self.ignore_eos = kwargs.get("ignore_eos", False) - self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) - self.max_new_tokens = kwargs.get("max_new_tokens", 16) - self.min_new_tokens = kwargs.get("min_new_tokens", 1) - self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) - self.group_request_id = kwargs.get("group_request_id", -1) - self.suggested_dp_index = kwargs.get("suggested_dp_index", -1) + + def _get(key, default): + v = kwargs.get(key) + return v if v is not None else default + + self.best_of = _get("best_of", 1) + self.n = _get("n", self.best_of) + self.do_sample = _get("do_sample", SamplingParams._do_sample) + self.presence_penalty = _get("presence_penalty", SamplingParams._presence_penalty) + self.frequency_penalty = _get("frequency_penalty", SamplingParams._frequency_penalty) + self.repetition_penalty = _get("repetition_penalty", SamplingParams._repetition_penalty) + self.temperature = _get("temperature", SamplingParams._temperature) + self.top_p = _get("top_p", SamplingParams._top_p) + self.top_k = _get("top_k", SamplingParams._top_k) + self.ignore_eos = _get("ignore_eos", False) + self.min_pixels = _get("min_pixels", -1) + self.max_pixels = _get("max_pixels", -1) + self.max_new_tokens = _get("max_new_tokens", 16) + self.min_new_tokens = _get("min_new_tokens", 1) + self.input_penalty = _get("input_penalty", DEFAULT_INPUT_PENALTY) + self.group_request_id = _get("group_request_id", -1) + self.suggested_dp_index = _get("suggested_dp_index", -1) self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) @@ -392,6 +422,11 @@ def init(self, tokenizer, **kwargs): self.allowed_token_ids = AllowedTokenIds() self.allowed_token_ids.initialize(allowed_token_ids) + # Initialize invalid_token_ids + invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys()) + self.invalid_token_ids = InvalidTokenIds() + self.invalid_token_ids.initialize(list[int](invalid_token_ids)) + if self.do_sample is False: self.temperature = 1.0 self.top_p = 1.0 @@ -408,13 +443,18 @@ def init(self, tokenizer, **kwargs): def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) except: pass @@ -482,6 +522,8 @@ def to_dict(self): "image_max_patch_num": self.image_max_patch_num, "max_new_tokens": self.max_new_tokens, "min_new_tokens": self.min_new_tokens, + "min_pixels": self.min_pixels, + "max_pixels": self.max_pixels, "exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(), "stop_sequences": self.stop_sequences.to_list(), "best_of": self.best_of, @@ -490,6 +532,7 @@ def to_dict(self): "guided_grammar": self.guided_grammar.to_str(), "guided_json": self.guided_json.to_str(), "allowed_token_ids": self.allowed_token_ids.to_list(), + "invalid_token_ids": self.invalid_token_ids.to_list(), "group_request_id": self.group_request_id, "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(), "skip_special_tokens": self.skip_special_tokens, diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py index c5ad512c6b..1bf20535ad 100644 --- a/lightllm/server/core/objs/shm_array.py +++ b/lightllm/server/core/objs/shm_array.py @@ -26,6 +26,19 @@ def link_shm(self): self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) return + def link_shm_partial(self): + """Link to an existing SHM that may be larger than the needed shape.""" + self.shm = create_or_link_shm(self.name, -1, force_mode="link") + assert self.shm.size >= self.dest_size, f"SHM {self.name} too small: need {self.dest_size}, got {self.shm.size}" + self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + + def detach_shm(self): + """Close handle without unlinking (SHM persists for reuse).""" + if self.shm is not None: + self.shm.close() + self.shm = None + self.arr = None + def close_shm(self): if self.shm is not None: self.shm.close() diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index a369cf7f7f..4ac0a4dd2b 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -1,37 +1,42 @@ from dataclasses import dataclass, field from typing import List, Optional, Tuple -# 只是为了更好的编程提示 +# 服务启动参数 @dataclass class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + metadata={ + "choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"] + }, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) + httpserver_workers: int = field(default=1) zmq_mode: str = field( default="ipc:///tmp/", metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"}, ) - pd_master_ip: str = field(default="127.0.0.1") + pd_master_ip: str = field(default="0.0.0.0") pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) pd_decode_rpyc_port: int = field(default=None) - select_p_d_node_strategy: str = field(default=None) + select_p_d_node_strategy: str = field( + default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]} + ) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) - tokenizer_mode: str = field(default="slow") + tokenizer_mode: str = field(default="fast") load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.9) batch_max_tokens: Optional[int] = field(default=None) - eos_id: List[int] = field(default_factory=list) + eos_id: Optional[List[int]] = field(default=None) tool_call_parser: Optional[str] = field( - default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]} + default=None, metadata={"choices": ["qwen25", "llama3", "mistral", "deepseekv3", "qwen"]} ) reasoning_parser: Optional[str] = field( default=None, @@ -59,7 +64,7 @@ class StartArgs: dp: int = field(default=1) nnodes: int = field(default=1) node_rank: int = field(default=0) - max_req_total_len: int = field(default=2048 + 1024) + max_req_total_len: int = field(default=16384) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=None) use_config_server_to_init_nccl: bool = field(default=False) @@ -75,7 +80,7 @@ class StartArgs: disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]}) + output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]}) first_token_constraint_mode: bool = field(default=False) enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) @@ -95,11 +100,11 @@ class StartArgs: health_monitor: bool = field(default=False) metric_gateway: Optional[str] = field(default=None) job_name: str = field(default="lightllm") - grouping_key: List[str] = field(default_factory=list) + grouping_key: List[str] = field(default_factory=lambda: []) push_interval: int = field(default=10) visual_infer_batch_size: int = field(default=None) visual_send_batch_size: int = field(default=1) - visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) + visual_gpu_ids: List[int] = field(default=None) visual_tp: int = field(default=1) visual_dp: int = field(default=1) visual_nccl_ports: List[int] = field(default=None) @@ -111,9 +116,9 @@ class StartArgs: graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) graph_max_len_in_batch: int = field(default=0) - quant_type: Optional[str] = field(default=None) + quant_type: Optional[str] = field(default="none") quant_cfg: Optional[str] = field(default=None) - vit_quant_type: Optional[str] = field(default=None) + vit_quant_type: Optional[str] = field(default="none") vit_quant_cfg: Optional[str] = field(default=None) llm_prefill_att_backend: List[str] = field( default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} @@ -144,7 +149,7 @@ class StartArgs: pd_node_id: int = field(default=-1) enable_cpu_cache: bool = field(default=False) cpu_cache_storage_size: float = field(default=2) - cpu_cache_token_page_size: int = field(default=64) + cpu_cache_token_page_size: int = field(default=256) enable_disk_cache: bool = field(default=False) disk_cache_storage_size: float = field(default=10) disk_cache_dir: Optional[str] = field(default=None) @@ -162,3 +167,23 @@ class StartArgs: # multi_modal enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) + + httpserver_workers: int = field(default=1) + disable_shm_warning: bool = field(default=False) + dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]}) + enable_custom_allgather: bool = field(default=False) + enable_fused_shared_experts: bool = field(default=False) + enable_mps: bool = field(default=False) + multinode_router_gloo_port: int = field(default=20001) + schedule_time_interval: float = field(default=0.03) + use_dynamic_prompt_cache: bool = field(default=False) + disable_custom_allreduce: bool = field(default=False) + enable_torch_memory_saver: bool = field(default=False) + enable_weight_cpu_backup: bool = field(default=False) + hardware_platform: str = field(default="cuda", metadata={"choices": ["cuda", "musa"]}) + enable_torch_fallback: bool = field(default=False) + enable_triton_fallback: bool = field(default=False) + + enable_return_routed_experts: bool = field(default=False) + + weight_version: str = "default" diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py index 9aa3a8effc..c77379986c 100644 --- a/lightllm/server/detokenization/decode_req.py +++ b/lightllm/server/detokenization/decode_req.py @@ -62,11 +62,7 @@ def stop_sequences_str_match(self) -> bool: return False def need_detoken(self): - if ( - (not self.req.is_aborted) - and (not self.req.stop_str_matched) - and len(self.output_ids) < self.req.candetoken_out_len - ): + if (not self.req.stop_str_matched) and len(self.output_ids) < self.req.candetoken_out_len: return True return False @@ -83,8 +79,6 @@ def get_decode_tokens(self): return prefix_tokens, read_tokens def can_set_release_mark(self): - if self.req.is_aborted: - return True if self.req.stop_str_matched: return True if ( diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8a..17a47dfde6 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -6,7 +6,6 @@ import zmq import inspect from lightllm.server.core.objs import ShmReqManager, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.utils.graceful_utils import graceful_registry from typing import Union, Dict, List from .decode import decode_token @@ -17,6 +16,14 @@ import time from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import ( + BaseReq, + GenerateResp, + FlushCacheResp, + ReleaseMemoryResp, + ResumeMemoryResp, + GeneralModelToHttpRpcRsp, +) logger = init_logger(__name__) @@ -31,9 +38,9 @@ def __init__( self.zmq_recv_socket = context.socket(zmq.PULL) self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}") - self.pub_to_httpserver = context.socket(zmq.PUB) - self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") - logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") + self.send_to_httpserver = context.socket(zmq.PUSH) + self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") + logger.info(f"send_to_httpserver sendhwm {self.send_to_httpserver.getsockopt(zmq.SNDHWM)}") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) self.all_special_ids = set(self.tokenizer.all_special_ids) self.req_id_to_out: Dict[int, DecodeReq] = {} @@ -46,7 +53,7 @@ def _init_get_token_id_to_token_str(self): self.token_id_to_token = {token_id: token for token, token_id in self.tokenizer.get_vocab().items()} return - def _add_new_group_req_index(self, recv_obj: GroupReqIndexes): + def _add_new_group_req_index(self, recv_obj: BaseReq): for req_index in recv_obj.shm_req_indexes: req = self.shm_req_manager.get_req_obj_by_index(req_index) req.link_prompt_ids_shm_array() @@ -74,8 +81,10 @@ def handle_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) + recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_obj, GeneralModelToHttpRpcRsp): + self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) + continue self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 @@ -146,7 +155,7 @@ def gen_token_out(self): # 通知 httpserver 进程 if exist_decode: - self.pub_to_httpserver.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_httpserver.send_pyobj(GenerateResp(), protocol=pickle.HIGHEST_PROTOCOL) self.remove_finished_reqs() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 212e037e90..3ef778ca42 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import hashlib import datetime import pickle +import inspect from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -25,16 +26,41 @@ from lightllm.server.core.objs import Req, FinishStatus, StartArgs from lightllm.server.core.objs import SamplingParams from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE -from lightllm.server.core.objs.io_objs import GroupReqObjs from lightllm.server.core.objs.shm_req_manager import ShmReqManager from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.common.basemodel.routing_manager import get_routing_config_shm from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient +from lightllm.server.io_struct import ( + AbortReq, + BaseReq, + FlushCacheReq, + FlushCacheResp, + GenerateReq, + GenerateResp, + GenerateReqMeta, + GenerateReqIndex, + ReleaseMemoryReq, + ReleaseMemoryResp, + ResumeMemoryReq, + ResumeMemoryResp, + InitWeightsUpdateGroupReq, + InitWeightsUpdateGroupRsp, + DestroyWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupRsp, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromDistributedRsp, + UpdateWeightsFromTensorReq, + UpdateWeightsFromTensorRsp, + GeneralHttpToModelRpcReq, + GeneralModelToHttpRpcRsp, +) from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.torch_memory_saver_utils import MemoryTag from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -74,7 +100,7 @@ def __init__( self.multinode_req_manager = context.socket(zmq.PULL) self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}") logger.info( - f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}" + f"HttpServerManager listening for master node requests on *:{args.multinode_httpmanager_port}" ) self.enable_multimodal = args.enable_multimodal @@ -90,9 +116,8 @@ def __init__( self.shm_req_manager = ShmReqManager() # recv from detokenization - self.zmq_recv_socket = context.socket(zmq.SUB) + self.zmq_recv_socket = context.socket(zmq.PULL) self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") - self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) @@ -114,6 +139,18 @@ def __init__( # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + + # Cache routing config for MoE expert routing data extraction + self._routing_shm = get_routing_config_shm() if args.enable_return_routed_experts else None + + self.is_pause = False + self.is_pause_cond = asyncio.Condition() + + # 交互式请求 event + self.flush_cache_event: Optional[asyncio.Event] = None + self.release_memory_event: Optional[asyncio.Event] = None + self.resume_memory_event: Optional[asyncio.Event] = None + self.async_events_per_func: Dict[str, asyncio.Event] = {} return async def _alloc_resource(self, items, md5sums, token_nums, datas): @@ -227,18 +264,32 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar async def loop_for_request(self): assert self.args.node_rank > 0 while True: - ( - prompt, - sampling_params, - multimodal_params, - ) = await self.multinode_req_manager.recv_pyobj() - results_generator = self.generate(prompt, sampling_params, multimodal_params, None) + req_obj = await self.multinode_req_manager.recv_pyobj() + if req_obj is None: + continue + if isinstance(req_obj, GenerateReqMeta): + self.process_generate_request(req_obj) + elif isinstance(req_obj, AbortReq): + self.process_abort_request(req_obj) + else: + assert False, f"Unknown request type: {type(req_obj)}" + return + + def process_generate_request(self, req_meta: GenerateReqMeta): + prompt = req_meta.prompt + sampling_params = req_meta.sampling_params + multimodal_params = req_meta.multimodal_params + results_generator = self.generate(prompt, sampling_params, multimodal_params, None) - async def generate_wrapper(results_generator): - async for _, _, _, _ in results_generator: - pass + async def generate_wrapper(results_generator): + async for _, _, _, _ in results_generator: + pass + + asyncio.create_task(generate_wrapper(results_generator)) + return - asyncio.create_task(generate_wrapper(results_generator)) + def process_abort_request(self, request: AbortReq): + asyncio.create_task(self.abort_request(request)) return def alloc_req_id(self, sampling_params, is_health_req: bool = False): @@ -281,15 +332,15 @@ async def generate( group_request_id = self.alloc_req_id(sampling_params, is_health_req) try: - original_multimodal_params = None - if self.is_multinode_tp_master: - original_multimodal_params = copy.deepcopy(multimodal_params) - if self.pd_mode.is_P_or_NORMAL(): await multimodal_params.verify_and_preload(request) # 记录请求到达的相关信息 await self._log_req_header(request_headers, group_request_id) + + async with self.is_pause_cond: + await self.is_pause_cond.wait_for(lambda: not self.is_pause) + # encode prompt_ids = await self._encode(prompt, multimodal_params, sampling_params) @@ -348,12 +399,17 @@ async def generate( ) req_objs.append(req_obj) - req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) + req_status = ReqStatus( + group_request_id=group_request_id, + prompt=prompt, + sampling_params=sampling_params, + multimodal_params=multimodal_params, + req_objs=req_objs, + start_time=start_time, + ) self.req_id_to_out_inf[group_request_id] = req_status - await self.transfer_to_next_module_or_node( - prompt, sampling_params, original_multimodal_params, req_status.group_req_objs - ) + await self.transfer_to_next_module_or_node(req_status.group_req_objs) results_generator = self._wait_to_token_package( start_time, @@ -441,7 +497,21 @@ async def _encode( # 这里的校验对多模态不是很充分, to do if all(isinstance(e, int) for e in prompt): - if not self.enable_multimodal and not self.pd_mode.is_D(): + if self.enable_multimodal: + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" + if multimodal_params.audios: + assert self.args.enable_multimodal_audio, "audio multimodal not enabled" + await self._alloc_multimodal_resources(multimodal_params, sampling_params) + prompt_ids = self.tokenizer.encode( + prompt, + multimodal_params, + add_special_tokens=sampling_params.add_special_tokens, + already_tokenized=True, + ) + return prompt_ids + elif not self.enable_multimodal and not self.pd_mode.is_D(): if all(e < self.vocab_size for e in prompt): return prompt else: @@ -484,44 +554,49 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params: async def transfer_to_next_module_or_node( self, - prompt: str, - sampling_params: SamplingParams, - original_multimodal_params: MultimodalParams, - group_req_objs: Optional[GroupReqObjs] = None, + req_obj: Optional["BaseReq"] = None, ): # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. + req_to_next_node = req_obj.get_req_to_next_node() + self.transfer_to_next_node(req_to_next_node) + req_to_next_module = req_obj.get_req_to_next_module() + await self.transfer_to_next_module(req_to_next_module) + return + + def transfer_to_next_node( + self, + req_to_next_node: Optional["BaseReq"] = None, + ): if self.is_multinode_tp_master: for sender in self.multinode_req_manager: sender.send_pyobj( - (prompt, sampling_params, original_multimodal_params), + req_to_next_node, protocol=pickle.HIGHEST_PROTOCOL, ) - - await self.transfer_to_next_module(group_req_objs) return async def transfer_to_next_module( self, - group_req_objs: Optional[GroupReqObjs] = None, + req_to_next_module: Optional["GenerateReqIndex"] = None, ): if self.pd_mode.is_P_or_NORMAL(): if self.enable_multimodal: self.send_to_visual.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return if self.args.enable_cpu_cache: self.send_to_multi_level_kv_cache.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -529,7 +604,7 @@ async def transfer_to_next_module( if self.pd_mode.is_D(): # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了 self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -659,12 +734,24 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_id} not exist") return False - group_req_objs: GroupReqObjs = req_status.group_req_objs + group_req_objs: GenerateReq = req_status.group_req_objs for req in group_req_objs.shm_req_objs: req.is_aborted = True logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + async def abort_request(self, request: AbortReq): + request_id = request.request_id + abort_all = request.abort_all + if self.is_multinode_tp_master: + self.transfer_to_next_node(req_to_next_node=request) + if request_id is not None and not abort_all: + await self.abort(request_id) + if abort_all: + for group_req_id in list(self.req_id_to_out_inf.keys()): + await self.abort(group_req_id) + pass + async def recycle_resource_loop(self): pre_time_mark = time.time() @@ -686,6 +773,11 @@ async def recycle_resource_loop(self): for req_status in release_req_status: self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) for req in req_status.group_req_objs.shm_req_objs: + if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None: + try: + req.close_routing_data_shm_array() + except Exception as e: + logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}") await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) @@ -721,65 +813,16 @@ async def handle_loop(self): while True: try: - await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) + recv_obj = await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) except asyncio.TimeoutError: - pass + recv_obj = None try: - for group_req_id_ in list(self.req_id_to_out_inf.keys()): - req_status = self.req_id_to_out_inf.get(group_req_id_, None) - if req_status is None: - continue - - token_list = [] - for req in req_status.group_req_objs.shm_req_objs: - req_id = req.request_id - read_token_count = 1 - if req.out_tokens_queue.is_full(): - read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE - - for _ in range(read_token_count): - if not req.out_tokens_queue.is_empty(): - - text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() - req.cumlogprob += float(req.shm_logprobs.arr[src_index]) - metadata = { - "id": int(req.shm_prompt_ids.arr[src_index]), - "logprob": float(req.shm_logprobs.arr[src_index]), - "cumlogprob": float(req.cumlogprob) / count_output_tokens, - "special": special, - "count_output_tokens": count_output_tokens, - "prompt_cache_len": req.prompt_cache_len, - "cpu_prompt_cache_len": req.cpu_prompt_cache_len, - "disk_prompt_cache_len": req.disk_prompt_cache_len, - "mtp_accepted_token_num": req.mtp_accepted_token_num, - } - if self.args.return_all_prompt_logprobs: - metadata.update(req.get_all_prompt_metadata()) - if self.args.use_reward_model: - metadata["score"] = float(req.reward_score) - - req.out_tokens_queue.pop_no_ret() - - finished_token_index = ( - req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index - ) - - if finished_token_index != src_index: - token_list.append((req_id, text, metadata, FinishStatus())) - else: - if req.stop_str_matched: - finish_status = FinishStatus(FinishStatus.FINISHED_STOP) - else: - finish_status = FinishStatus(req.finish_status.status) - - token_list.append((req_id, text, metadata, finish_status)) - else: - break + if recv_obj is None or isinstance(recv_obj, GenerateResp): + await self._handle_recv_generate_request(recv_obj) + elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): + await self._handle_recv_general_model_to_http_request(recv_obj) - async with req_status.lock: - req_status.out_token_info_list.extend(token_list) - req_status.event.set() except BaseException as e: logger.exception(str(e)) raise e @@ -787,13 +830,189 @@ async def handle_loop(self): self.recycle_event.set() return + async def _handle_recv_generate_request(self, recv_obj: GenerateReqMeta): + for group_req_id_ in list(self.req_id_to_out_inf.keys()): + req_status = self.req_id_to_out_inf.get(group_req_id_, None) + if req_status is None: + continue + + token_list = [] + for req in req_status.group_req_objs.shm_req_objs: + req_id = req.request_id + read_token_count = 1 + if req.out_tokens_queue.is_full(): + read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE + + for _ in range(read_token_count): + if not req.out_tokens_queue.is_empty(): + + text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() + req.cumlogprob += float(req.shm_logprobs.arr[src_index]) + metadata = { + "id": int(req.shm_prompt_ids.arr[src_index]), + "logprob": float(req.shm_logprobs.arr[src_index]), + "cumlogprob": float(req.cumlogprob) / count_output_tokens, + "special": special, + "count_output_tokens": count_output_tokens, + "prompt_cache_len": req.prompt_cache_len, + "cpu_prompt_cache_len": req.cpu_prompt_cache_len, + "mtp_accepted_token_num": req.mtp_accepted_token_num, + } + if self.args.return_all_prompt_logprobs: + metadata.update(req.get_all_prompt_metadata()) + if self.args.use_reward_model: + metadata["score"] = float(req.reward_score) + + req.out_tokens_queue.pop_no_ret() + + finished_token_index = ( + req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index + ) + + if finished_token_index != src_index: + token_list.append((req_id, text, metadata, FinishStatus())) + else: + if req.stop_str_matched: + finish_status = FinishStatus(FinishStatus.FINISHED_STOP) + else: + finish_status = FinishStatus(req.finish_status.status) + + if self._routing_shm is not None: + _num_moe = int(self._routing_shm.arr[0]) + _topk = int(self._routing_shm.arr[1]) + _dtype_id = int(self._routing_shm.arr[2]) + if _num_moe > 0: + routing_meta = req.get_routing_metadata(_num_moe, _topk, dtype_id=_dtype_id) + if routing_meta is not None: + metadata["routed_experts"] = routing_meta + + token_list.append((req_id, text, metadata, finish_status)) + else: + break + + async with req_status.lock: + req_status.out_token_info_list.extend(token_list) + req_status.event.set() + + async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralModelToHttpRpcRsp): + assert recv_obj.func_name is not None + event = await self.get_event_for_func(recv_obj.func_name) + event.result = recv_obj + event.set() + return + + async def pause_generation(self): + # 因为请求是从master node转发到slave node的 + # 所以只要master暂停了,slave自然暂停。 + if self.is_pause: + return + async with self.is_pause_cond: + self.is_pause = True + while True: + await self.abort_request(AbortReq(request_id=None, abort_all=True)) + running_req_num = len(list(self.req_id_to_out_inf.keys())) + if running_req_num == 0: + break + await asyncio.sleep(1.0) + + async def continue_generation(self): + async with self.is_pause_cond: + self.is_pause = False + self.is_pause_cond.notify_all() + + async def get_event_for_func(self, func_name: str) -> asyncio.Event: + if func_name not in self.async_events_per_func: + self.async_events_per_func[func_name] = asyncio.Event() + return self.async_events_per_func[func_name] + + async def http_to_model_special_request( + self, request: GeneralHttpToModelRpcReq, timeout: int = 300 + ) -> GeneralModelToHttpRpcRsp: + event = await self.get_event_for_func(request.func_name) + await self.transfer_to_next_module(request) + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + ret = event.result + + except asyncio.TimeoutError: + ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response timeout", func_name=request.func_name) + except Exception as e: + ret = GeneralModelToHttpRpcRsp( + success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name + ) + return ret + + async def flush_cache(self, request: FlushCacheReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="flush_cache", func_args=request) + ) + + async def release_memory_occupation(self, request: ReleaseMemoryReq): + assert len(self.req_id_to_out_inf) == 0, "there are still requests running, cannot release memory occupation" + # 暂停接受请求,除非resume + await self.pause_generation() + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="release_memory_occupation", func_args=request.tags) + ) + + async def resume_memory_occupation(self, request: ResumeMemoryReq): + ret = await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="resume_memory_occupation", func_args=request.tags) + ) + if ret.success: + await self.continue_generation() + return ret + + async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="init_weights_update_group", func_args=request) + ) + + async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="destroy_weights_update_group", func_args=request) + ) + + async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + + if request.flush_cache: + await self.flush_cache(FlushCacheReq()) + + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request) + ) + + async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> Tuple[bool, str]: + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + + if request.flush_cache: + await self.flush_cache(FlushCacheReq()) + + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="update_weights_from_tensor", func_args=request) + ) + class ReqStatus: - def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: + def __init__( + self, + group_request_id: int, + prompt: str, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + req_objs: List[Req], + start_time, + ) -> None: self.lock = asyncio.Lock() self.event = asyncio.Event() - self.group_req_objs = GroupReqObjs( + self.group_req_objs = GenerateReq( group_req_id=group_request_id, + prompt=prompt, + sampling_params=sampling_params, multimodal_params=multimodal_params, shm_req_objs=req_objs, time_mark=start_time, diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py new file mode 100644 index 0000000000..e04e8871ce --- /dev/null +++ b/lightllm/server/io_struct.py @@ -0,0 +1,195 @@ +from abc import ABC +from dataclasses import dataclass +from lightllm.server.core.objs.req import Req +from lightllm.server.core.objs.sampling_params import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from typing import List, Optional, Any, Union +from lightllm.utils.torch_memory_saver_utils import MemoryTag + + +@dataclass +class BaseReq(ABC): + def get_req_to_next_node(self): + return self + + def get_req_to_next_module(self): + return self + + +@dataclass +class BaseRsp(ABC): + success: bool + msg: Optional[str] + + +# for next node +@dataclass +class GenerateReqMeta(BaseReq): + prompt: str + sampling_params: SamplingParams + multimodal_params: MultimodalParams + + +# for next module +@dataclass +class GenerateReqIndex(BaseReq): + group_req_id: int + multimodal_params: MultimodalParams + shm_req_indexes: List[int] + time_mark: float + + +@dataclass +class GenerateReq(BaseReq): + group_req_id: int + prompt: str + sampling_params: SamplingParams + multimodal_params: MultimodalParams + shm_req_objs: List[Req] + time_mark: float + + def get_req_to_next_module(self): + # 已经完成跨节点转发,可以释放图片原始资源 + self.multimodal_params.free() + return GenerateReqIndex( + group_req_id=self.group_req_id, + multimodal_params=self.multimodal_params, + shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], + time_mark=self.time_mark, + ) + + def get_req_to_next_node(self): + return GenerateReqMeta( + prompt=self.prompt, + sampling_params=self.sampling_params, + multimodal_params=self.multimodal_params, + ) + + +@dataclass +class GenerateResp(BaseReq): + pass + + +@dataclass +class FlushCacheReq(BaseReq): + pass + + +@dataclass +class FlushCacheResp(BaseReq): + success: bool + + +@dataclass +class AbortReq(BaseReq): + # 外部调用传入,等同内部的 group_req_id + request_id: int = None + abort_all: bool = False + + +@dataclass +class ReleaseMemoryReq(BaseReq): + tags: Optional[List[MemoryTag]] = None + + +@dataclass +class ReleaseMemoryResp(BaseReq): + success: bool + + +@dataclass +class ResumeMemoryReq(BaseReq): + tags: Optional[List[MemoryTag]] = None + + +@dataclass +class ResumeMemoryResp(BaseReq): + success: bool + + +@dataclass +class GeneralHttpToModelRpcReq(BaseReq): + func_name: str + func_args: Optional[Any] = None + + +@dataclass +class GeneralModelToHttpRpcRsp(BaseRsp): + func_name: str + func_rsp: Optional[Any] = None + + +@dataclass +class InitWeightsUpdateGroupReq(BaseReq): + # The master address + master_address: str + # The master port + master_port: int + # The rank offset + rank_offset: int + # The world size + world_size: int + # The group name + group_name: str = "weight_update_group" + # The backend + backend: str = "nccl" + + +@dataclass +class InitWeightsUpdateGroupRsp(BaseRsp): + pass + + +@dataclass +class DestroyWeightsUpdateGroupReq(BaseReq): + group_name: str = "weight_update_group" + + +@dataclass +class DestroyWeightsUpdateGroupRsp(BaseRsp): + pass + + +@dataclass +class UpdateWeightsFromDistributedReq(BaseReq): + names: List[str] + dtypes: List[str] + shapes: List[List[int]] + # The group name + group_name: str = "weight_update_group" + # Whether to flush the cache after updating weights + flush_cache: bool = True + # Whether to abort all requests before updating weights + abort_all_requests: bool = False + # Optional: Update weight version along with weights + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightsFromDistributedRsp(BaseRsp): + pass + + +@dataclass +class UpdateWeightsFromTensorReq(BaseReq): + """Update model weights from tensor input. + + - Tensors are serialized for transmission + - Data is structured in JSON for easy transmission over HTTP + """ + + serialized_named_tensors: List[Union[str, bytes]] + # Optional format specification for loading + load_format: Optional[str] = None + # Whether to flush the cache after updating weights + flush_cache: bool = True + # Whether to abort all requests before updating weights + abort_all_requests: bool = False + # Optional: Update weight version along with weights + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightsFromTensorRsp(BaseRsp): + pass diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py index 1de1b502c9..b8ff820ec5 100644 --- a/lightllm/server/multi_level_kv_cache/manager.py +++ b/lightllm/server/multi_level_kv_cache/manager.py @@ -12,7 +12,7 @@ from queue import Queue from typing import List from lightllm.server.core.objs import ShmReqManager, Req, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.server.io_struct import GenerateReqIndex, BaseReq from lightllm.utils.graceful_utils import graceful_registry from .cpu_cache_client import CpuKvCacheClient from lightllm.utils.log_utils import init_logger @@ -135,7 +135,7 @@ def _disk_cache_match(self, token_hash_list: List[int], all_pages: List[int]) -> self.cpu_cache_client.lock.release() return all_pages, len(new_page_indexes) - def _handle_group_req_multi_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float): + def _handle_group_req_multi_cache_match(self, group_req_indexes: GenerateReqIndex, start_time: float): """ match cpu cache and disk cache pages """ @@ -198,8 +198,9 @@ def recv_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) + recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if not isinstance(recv_obj, GenerateReqIndex): + continue recv_objs.append(recv_obj) start_time = recv_obj.time_mark diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index d957c7649a..72db3d80cd 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -56,9 +56,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -121,9 +123,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -174,3 +178,10 @@ def to_origin_dict(self): ret = {} ret["images"] = [i.to_origin_dict() for i in self.images] return ret + + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + return diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py index 9bf9040c30..20da121dc0 100644 --- a/lightllm/server/req_id_generator.py +++ b/lightllm/server/req_id_generator.py @@ -30,7 +30,8 @@ def __init__(self): self.current_id.arr[0] = 0 self.current_id.arr[1] = 0 self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock") - self._wait_all_workers_ready() + if self.args.httpserver_workers > 1: + self._wait_all_workers_ready() logger.info("ReqIDGenerator init finished") def _wait_all_workers_ready(self): diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index c517748984..e96b1d1a32 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -424,6 +424,30 @@ def clear_tree_nodes(self): self.refed_tokens_num.arr[0] = 0 return + def flush_cache(self): + nodes_to_clear = collections.deque(self.root_node.children.values()) + self.root_node.children.clear() + while nodes_to_clear: + node = nodes_to_clear.popleft() + nodes_to_clear.extend(node.children.values()) + node.parent = None + node.children.clear() + + self.root_node.token_id_key[:] = 0 + self.root_node.token_mem_index_value[:] = 0 + self.root_node.ref_counter = 1 # 保持为1,确保不会被evict + self.root_node.time_id = time_gen.generate_time_id() + self.root_node.node_value_len = 0 + self.root_node.node_prefix_total_len = 0 + + self.evict_tree_set.clear() + self.evict_tree_set.add(self.root_node) + + self.tree_total_tokens_num.arr[0] = 0 + self.refed_tokens_num.arr[0] = 0 + + return + def dec_node_ref_counter(self, node: TreeNode): if node is None: return diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index ac5c1abee3..7d5cfa7bea 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -5,6 +5,7 @@ import pickle import inspect import setproctitle +import rpyc asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) import zmq @@ -17,7 +18,6 @@ from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue from lightllm.server.core.objs.io_objs import ( - GroupReqIndexes, AbortedReqCmd, StopStrMatchedReqCmd, ) @@ -29,11 +29,23 @@ from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock +from lightllm.server.io_struct import ( + BaseReq, + GenerateReqIndex, + FlushCacheReq, + FlushCacheResp, + ReleaseMemoryReq, + ReleaseMemoryResp, + ResumeMemoryReq, + ResumeMemoryResp, + GeneralHttpToModelRpcReq, + GeneralModelToHttpRpcRsp, +) from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name - +from lightllm.utils.torch_memory_saver_utils import MemoryTag logger = init_logger(__name__) @@ -356,8 +368,13 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]: ans = [] if self.running_batch is None: return ans - for req in self.running_batch.reqs: - if req.is_aborted and req._router_aborted is False: + aborted_req_mask = torch.tensor( + [req.is_aborted for req in self.running_batch.reqs], dtype=torch.bool, device="cpu" + ) + if self.is_multinode_tp: + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + for req, is_aborted in zip(self.running_batch.reqs, aborted_req_mask.numpy()): + if is_aborted and req._router_aborted is False: req._router_aborted = True ans.append(req) return ans @@ -406,7 +423,7 @@ def get_used_tokens(self, dp_index): else: return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) - def _add_req(self, group_req_indexes: GroupReqIndexes): + def _add_req(self, group_req_indexes: BaseReq): req_group = [] for req_index in group_req_indexes.shm_req_indexes: req = self.shm_req_manager.get_req_obj_by_index(req_index) @@ -452,9 +469,22 @@ def _multinode_tp_generate_new_batch(self): dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) req_id_select_mark = [1 for _ in range(len(req_ids))] req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") + # TODO: 这里可以合成一个 allreudce,req_id_select_mark + aborted_req_mask dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + aborted_req_mask = torch.tensor( + [req.is_aborted for req in new_batch.reqs], dtype=torch.bool, device="cpu" + ) + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) back_req_list = [] - for req_id, select in zip(req_ids, req_id_select_mark.numpy()): + for req_id, select, is_aborted in zip( + req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy() + ): + # 释放多节点abort 请求,如果select == 0, is_aborted 一定为False + if is_aborted and select == 1: + req = new_batch.pop_req(req_id) + self.req_queue.free_aborted_req(req) + self.shm_req_manager.put_back_req_obj(req) + continue if select == 0: req = new_batch.pop_req(req_id) back_req_list.append(req) @@ -470,23 +500,28 @@ def _multinode_tp_generate_new_batch(self): else: req_ids = [None for _ in range(req_num)] dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) - all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list]) + # all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list]) + id_to_req_obj = {req.request_id: req for req in self.req_queue.waiting_req_list} req_id_select_mark = [] + aborted_req_mask = [] for req_id in req_ids: - req_id_select_mark.append(1 if req_id in all_req_id_set else 0) + req_id_select_mark.append(1 if req_id in id_to_req_obj else 0) + aborted_req_mask.append(id_to_req_obj[req_id].is_aborted if req_id in id_to_req_obj else False) req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) - select_req_ids = [] - for req_id, select in zip(req_ids, req_id_select_mark.numpy()): - if select == 1: - select_req_ids.append(req_id) - + aborted_req_mask = torch.tensor(aborted_req_mask, dtype=torch.bool, device="cpu") + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) select_reqs = [] - for req_id in select_req_ids: - for req in self.req_queue.waiting_req_list: - if req.request_id == req_id: - select_reqs.append(req) - + for req_id, select, is_aborted in zip( + req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy() + ): + if select == 1: + req = id_to_req_obj[req_id] + if is_aborted: + self.req_queue.free_aborted_req(req) + self.shm_req_manager.put_back_req_obj(req) + continue + select_reqs.append(req) for req in select_reqs: self.req_queue.waiting_req_list.remove(req) if select_reqs: @@ -507,13 +542,17 @@ async def _recv_new_reqs_and_schedule(self): self.recv_max_count = 64 try: + # 多机tp需要广播给其他node的请求 + special_reqs = [] # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) + elif isinstance(recv_req, GeneralHttpToModelRpcReq): + special_reqs.append(recv_req) else: - assert False, f"Error Req Inf {recv_req}" + raise ValueError(f"Unknown request type: {type(recv_req)}") # 当队列中存在较多的请求时,将一次接受的数量上调 self.recv_max_count = min(int(self.recv_max_count * 1.3), 256) @@ -522,6 +561,8 @@ async def _recv_new_reqs_and_schedule(self): # 当队列已经开始清空的时候,将一次接受的数量下调 self.recv_max_count = 64 + self._process_special_reqs(special_reqs) + if self.is_multinode_tp: self._multinode_tp_generate_new_batch() else: @@ -529,6 +570,44 @@ async def _recv_new_reqs_and_schedule(self): self._generate_new_batch() return + def _process_special_reqs(self, special_reqs: List[BaseReq]): + if self.is_multinode_tp: + special_reqs = self.broadcast_reqs_to_other_nodes(special_reqs) + for req in special_reqs: + assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq" + self.forward_to_model(req) + + def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): + req_num = len(reqs) + if self.node_rank == 0: + req_nums = [len(reqs)] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + else: + req_nums = [None] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + reqs = [None for _ in range(req_num)] + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + return reqs + + def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None: + ret = self.model_rpc_client.forward_to_model(req) + if self.is_multinode_tp: + output_list = [None for _ in self.nnodes] if self.node_rank == 0 else None + dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group) + for res in output_list: + res: GeneralModelToHttpRpcRsp + if not res.success: + ret = res + break + + if self.node_rank == 0: + self.send_to_detokenization.send_pyobj(ret, protocol=pickle.HIGHEST_PROTOCOL) + def clean_up(self): return diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4b8b3c538f..66aeb6e95d 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -19,6 +19,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.common.basemodel import routing_manager as _routing_mgr logger = init_logger(__name__) @@ -113,6 +114,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs + def _extract_routing_data(self, req: "InferReq"): + if req.shm_req.shm_routing_num_tokens > 0: + return + mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] + mgr = _routing_mgr.g_routing_capture_manager + routing_data = mgr.extract_routing_data(mem_indexes) + req.shm_req.create_routing_data_shm_array(mgr.num_moe_layers, req.cur_kv_len, mgr.topk, np_dtype=mgr.np_dtype) + req.shm_req.shm_routing_data.arr[:] = routing_data + req.shm_req.shm_routing_data.detach_shm() + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) @@ -149,12 +160,18 @@ def _filter(self, finished_request_ids: List[int]): if len(finished_request_ids) == 0: return + need_routing_data = _routing_mgr.g_routing_capture_manager is not None + free_req_index = [] free_token_index = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() + + if need_routing_data: + self._extract_routing_data(req) + self.free_a_req_mem(free_token_index, req) free_req_index.append(req.req_idx) @@ -266,6 +283,7 @@ def __init__( self.fsm_current_state: int = 0 self.allowed_token_ids = self.shm_param.allowed_token_ids.to_list() + self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list() if len(self.allowed_token_ids) == 0: self.allowed_token_ids = None @@ -281,6 +299,11 @@ def __init__( logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids") self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size] + if len(self.invalid_token_ids) > 0: + if not all(e < vocab_size for e in self.invalid_token_ids): + logger.error("invalid_token_ids contain tokenid >= vobsize, we remove these token ids") + self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size] + # nixl decode node information if self.shm_param.nixl_params.data_len > 0: self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get()) @@ -491,6 +514,8 @@ def update_finish_status(self, eos_ids, output_len: int): self.finish_status.set_status(FinishStatus.FINISHED_STOP) elif output_len >= self.sampling_param.shm_param.max_new_tokens: self.finish_status.set_status(FinishStatus.FINISHED_LENGTH) + elif self.infer_aborted: + self.finish_status.set_status(FinishStatus.FINISHED_ABORTED) return def _stop_sequences_matched(self, output_len: int): @@ -580,6 +605,8 @@ def handle( shm_req.shm_cur_output_len = self.output_len if finish_status.is_finished(): + if _routing_mgr.g_routing_capture_manager is not None: + g_infer_context._extract_routing_data(req_obj) shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 64310d6b0d..70b0ec9ebf 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -4,7 +4,7 @@ import time import threading import torch.distributed as dist -from typing import List, Tuple, Callable, Optional +from typing import List, Tuple, Callable, Optional, Union from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger @@ -16,7 +16,7 @@ from lightllm.common.basemodel.basemodel import TpPartBaseModel from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify -from lightllm.utils.dist_utils import init_distributed_env +from lightllm.utils.dist_utils import init_distributed_env, init_custom_process_group from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd @@ -31,6 +31,9 @@ enable_radix_tree_timer_merge, get_radix_tree_merge_update_delta, ) +from lightllm.utils.serializer import LocalSerializedTensor, MultiprocessingSerializer +from lightllm.utils.patch_torch import monkey_patch_torch_reductions +from lightllm.utils.tensor_bucket import FlattenedTensorBucket, FlattenedTensorMetadata from lightllm.distributed import dist_group_manager from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack @@ -41,7 +44,16 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.common.basemodel import routing_manager as _routing_mgr +from lightllm.utils.torch_memory_saver_utils import MemoryTag from .multi_level_kv_cache import MultiLevelKvCacheModule +from lightllm.server.io_struct import ( + FlushCacheReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, +) class ModeBackend: @@ -114,6 +126,8 @@ def init_model(self, kvargs): ) dist_group_manager.create_groups(group_size=group_size) # set the default group + self._model_update_group = {} + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 @@ -338,6 +352,191 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return + def flush_cache(self, request: FlushCacheReq): + if self.radix_cache is not None: + self.radix_cache.flush_cache() + return True, "Succeeded to flush cache." + + def release_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.release_memory_occupation(tags) + self.flush_cache(request=None) + self.model.req_manager.free_all() + self.model.mem_manager.free_all() + return True, "Succeeded to release memory occupation." + except Exception as e: + self.logger.error(f"release memory occupation failed: {str(e)}") + return False, f"release memory occupation failed: {str(e)}" + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.resume_memory_occupation(tags) + return True, "Succeeded to resume memory occupation." + except Exception as e: + self.logger.error(f"resume memory occupation failed: {str(e)}") + return False, f"resume memory occupation failed: {str(e)}" + + def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + assert torch.distributed.is_initialized(), "Default torch process group must be initialized" + + assert request.group_name != "", "Group name cannot be empty" + rank_offset = request.rank_offset + rank = rank_offset + self.rank_in_dp + world_size = request.world_size + group_name = request.group_name + self.logger.info( + f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, " + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, " + f" backend={request.backend}" + ) + + try: + if group_name in self._model_update_group: + raise ValueError(f"Process group with name {group_name} already exists.") + + self._model_update_group[group_name] = init_custom_process_group( + backend=request.backend, + init_method=f"tcp://{request.master_address}:{request.master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) + return True, "Succeeded to initialize custom process group." + + except Exception as e: + message = f"Failed to initialize custom process group: {e}." + self.logger.error(message) + return False, message + + def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + try: + if request.group_name in self._model_update_group: + pg = self._model_update_group.pop(request.group_name) + torch.distributed.destroy_process_group(pg) + return True, "Succeeded to destroy custom process group." + else: + return False, "The group to be destroyed does not exist." + except Exception as e: + message = f"Failed to destroy custom process group: {e}." + self.logger.error(message) + return False, message + + def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + """ + Update specific parameter in the model weights online + through `_model_update_group` process group. + + Args: + name: the name of the parameter to be updated. + dtype: the data type of the parameter to be updated. + shape: the shape of the parameter to be updated. + """ + + assert request.group_name in self._model_update_group, ( + f"Group {request.group_name} not in {list(self._model_update_group.keys())}. " + "Please call `init_weights_update_group` first." + ) + + try: + weights = [] + handles = [] + for name, dtype, shape in zip(request.names, request.dtypes, request.shapes): + target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + weight = torch.empty(shape, dtype=target_dtype, device="cuda") + handles.append( + torch.distributed.broadcast( + weight, + src=0, + group=self._model_update_group[request.group_name], + async_op=True, + ) + ) + weights.append((name, weight)) + for handle in handles: + handle.wait() + + self.model.load_weights(weights) + return True, "Succeeded to update parameter online from distributed." + + except Exception as e: + error_msg = ( + f"Failed to update parameter online: {e}. " + f"The full weights of the ModelRunner are partially updated. " + f"Please discard the whole weights." + ) + self.logger.error(error_msg) + return False, error_msg + + def _update_weights_from_flattened_bucket( + self, + flattened_tensor_bucket_dict, + ): + """Handle flattened bucket format for weight updates""" + flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"] + metadata = flattened_tensor_bucket_dict["metadata"] + + # Convert metadata dict to our format + converted_metadata = [] + for meta in metadata: + converted_meta = FlattenedTensorMetadata( + name=meta.name, + shape=meta.shape, + dtype=meta.dtype, + start_idx=meta.start_idx, + end_idx=meta.end_idx, + numel=meta.numel, + ) + converted_metadata.append(converted_meta) + + # Create bucket and reconstruct tensors + bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata) + reconstructed_tensors = bucket.reconstruct_tensors() + + named_tensors = {name: tensor for name, tensor in reconstructed_tensors} + + # Load the reconstructed tensors using the standard method + self.model.load_weights(named_tensors) + + return True, "Succeeded to update parameter online from flattened bucket tensor." + + def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): + try: + monkey_patch_torch_reductions() + if request.load_format == "flattened_bucket": + # Handle flattened bucket format + serialized_named_tensors = MultiprocessingSerializer.deserialize( + request.serialized_named_tensors[self.rank_in_dp] + ) + return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=serialized_named_tensors) + + # We need to get device after patch otherwise the device would be wrong + self.device_module = torch.get_device_module("cuda") + infered_device = self.device_module.current_device() + + named_tensors = MultiprocessingSerializer.deserialize(request.serialized_named_tensors[self.rank_in_dp]) + + def _unwrap_tensor(tensor, tp_rank, device): + if isinstance(tensor, LocalSerializedTensor): + tensor = tensor.get(tp_rank) + clone = tensor.to(device).clone() + del tensor # free the ipc tensor + return clone + + named_tensors = { + name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) + for name, tensor in named_tensors + } + + self.model.load_weights(named_tensors) + + return True, "Succeeded to update parameter online from tensor." + + except Exception as e: + message = f"Failed to update parameter online from tensor. Reason: {e}." + self.logger.error(message) + + return False, message + def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor): """ 这个函数会把next token id和logprobs保存到pinned memory中 @@ -798,6 +997,18 @@ def _sample_and_scatter_token( ) return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu + def _flush_routing_to_kv_buffer(self, mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: + """Scatter captured routing data from capture buffer to KV-indexed GPU buffer. + + Must be called AFTER model.forward() completes. mem_indexes should be the + original (unpadded) tensor — either CPU or CUDA. + """ + if _routing_mgr.g_routing_capture_manager is not None and mem_indexes is not None: + if not mem_indexes.is_cuda: + mem_indexes = mem_indexes.cuda(non_blocking=True) + num_tokens = mem_indexes.shape[0] + _routing_mgr.g_routing_capture_manager.flush_to_routing_buffer(mem_indexes, num_tokens, microbatch_index) + def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index a8a5224ebc..9f4443e48e 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -109,6 +109,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -148,6 +149,7 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -186,6 +188,7 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -236,6 +239,7 @@ def decode_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5a179cb620..ebc55b7ef4 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -40,8 +40,8 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq ) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) logits = model_output.logits batch_idx, run_reqs = self._diverse_copy( diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bb0e848e76..f01e5fe935 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -145,6 +145,7 @@ def prefill_normal( run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -188,6 +189,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -236,6 +238,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -305,6 +309,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -359,6 +365,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] req_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num] logits = model_output.logits[0:req_num, :] b_req_idx = model_input.b_req_idx[0:req_num] @@ -421,6 +428,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None if req_num > 0: logits = model_output.logits[0:req_num, :] @@ -629,6 +637,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits req_num0, req_num1 = len(run_reqs0), len(run_reqs1) @@ -726,8 +736,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = model_input1.b_mtp_index with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits run_reqs = run_reqs0 + run_reqs1 diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index e2ccf290e8..fc551b08ea 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,7 +1,8 @@ import torch from typing import List -from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty -from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.utils.envs_utils import get_env_start_args @@ -14,7 +15,10 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_top_ks, b_length_penalty_param, b_mask_eos_reqs, + invalid_token_ids, + cu_invalid_token_num, is_all_greedy, + has_invalid_token_ids, ) = _get_post_sample_tensors(reqs) eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True) @@ -59,6 +63,14 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): eos_ids=eos_ids, sampling_params_manager=sampling_params_manager, ) + + if has_invalid_token_ids: + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + logits.div_(b_temperatures.view((-1, 1))) probs = torch.softmax(logits, dim=-1) @@ -112,6 +124,12 @@ def _get_post_sample_tensors(reqs: List[InferReq]): mask_eos_reqs: List[bool] = [] is_all_greedy = True + # invalid token ids + invalid_token_ids: List[int] = [] + has_invalid_token_ids = False + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for i, req_obj in enumerate(reqs): sample_param = req_obj.sampling_param shm_param = sample_param.shm_param @@ -127,6 +145,11 @@ def _get_post_sample_tensors(reqs: List[InferReq]): if top_k_val > 1: is_all_greedy = False req_idxes.append(req_obj.req_idx) + invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids) + cu_invalid_token_num.append(invalid_token_num_start) + if len(req_obj.sampling_param.invalid_token_ids) > 0: + has_invalid_token_ids = True + invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids) req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True) temperatures_cpu = torch.tensor(temperatures, dtype=torch.float, device="cpu", pin_memory=True) @@ -135,6 +158,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): length_penalty_param_cpu = torch.tensor(length_penalty_param, dtype=torch.int32, device="cpu", pin_memory=True) mask_eos_reqs_cpu = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cpu", pin_memory=True) + if has_invalid_token_ids: + invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True) + cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True) + return ( req_idxes_cpu.cuda(non_blocking=True), temperatures_cpu.cuda(non_blocking=True), @@ -142,5 +169,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]): top_ks_cpu.cuda(non_blocking=True), length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), + invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, + cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, is_all_greedy, + has_invalid_token_ids, ) diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 55fe7a415e..d7b85adf68 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -33,6 +33,8 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.torch_memory_saver_utils import MemoryTag +from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp logger = init_logger(__name__) @@ -179,6 +181,34 @@ def init_model(self, kvargs): def get_max_total_token_num(self): return self.backend.get_max_total_token_num() + def release_memory_occupation(self, tags: List[MemoryTag]): + try: + self.backend.release_memory_occupation(tags) + return True + except BaseException as e: + logger.exception(f"release memory occupation failed: {str(e)}") + return False + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.backend.resume_memory_occupation(tags) + return True + except BaseException as e: + logger.exception(f"resume memory occupation failed: {str(e)}") + return False + + def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + try: + if self.backend is None or not hasattr(self.backend, req.func_name): + raise ValueError(f"Backend does not support function {req.func_name}") + success, ret = getattr(self.backend, req.func_name)(req.func_args) + return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret) + except BaseException as e: + logger.exception(f"forward to model backend failed: {str(e)}") + return GeneralModelToHttpRpcRsp( + success=False, msg=f"forward to model backend failed: {str(e)}", func_name=req.func_name + ) + class ModelRpcClient: def __init__(self, rpc_event, rpc_finished_event): @@ -209,6 +239,16 @@ async def get_max_total_token_num(self): assert func_name == "get_max_total_token_num" return ret + def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + self.rpc_shm_params.write_func_params("forward_to_model", (req,)) + self.rpc_event.set() + + self.rpc_finished_event.wait() + self.rpc_finished_event.clear() + func_name, ret = self.rpc_shm_results.read_func_result() + assert func_name == "forward_to_model" + return ret + def _init_env( args, @@ -269,7 +309,11 @@ async def start_model_process( success_event, ), ) - proc.start() + from lightllm.utils.torch_memory_saver_utils import TorchMemorySaverWrapper + + torch_memory_saver = TorchMemorySaverWrapper(args.enable_torch_memory_saver) + with torch_memory_saver.configure_subprocess(): + proc.start() # Use asyncio.to_thread to make the blocking wait non-blocking await asyncio.to_thread(success_event.wait, timeout=40) diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index 36aefae6e7..d7ef06828b 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -34,6 +34,17 @@ def free_aborted_req_cpu_cache_pages(self, req: Req): req.cpu_cache_match_page_indexes.clear() self.router.cpu_cache_client.lock.release() + def free_aborted_req(self, req: Req): + # 为了让http server 能正常返回请求,还没有开始推理的请求,直接设置结束,返回空字符串 + input_len = req.input_len + req.link_prompt_ids_shm_array() + req.link_logprobs_shm_array() + req.candetoken_out_len = 1 + req.finish_token_index = input_len + req.shm_prompt_ids.arr[input_len] = self.args.eos_id[0] + req.shm_logprobs.arr[input_len] = 0 + req.finish_status.set_status(FinishStatus.FINISHED_ABORTED) + def extend(self, req_group: List[Req]): for req in req_group: req.sample_params.suggested_dp_index = self.dp_index diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py index ae7c90b335..ed2a5dbb12 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py @@ -89,7 +89,7 @@ def generate_new_batch(self, current_batch: Batch): aborted_count = 0 cur_group_reqs = [] for req in self.waiting_req_list: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: aborted_count += 1 abort_req_list.append(req) continue @@ -111,7 +111,7 @@ def generate_new_batch(self, current_batch: Batch): ok_insert, new_batch_first_router_need_tokens = self._can_add_new_group_reqs( cur_group_reqs, is_busy, new_batch_first_router_need_tokens ) - if ok_insert: + if ok_insert and False: can_run_list.extend(cur_group_reqs) new_batch = None @@ -120,6 +120,7 @@ def generate_new_batch(self, current_batch: Batch): for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 0d870b55d8..9449798e9c 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -79,8 +79,8 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: - # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. + if req.is_aborted and not self.router.is_multinode_tp: + # 由于管理的复杂性,只有没有被调度运行过的单节点请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 abort_req_list.append(req) @@ -97,6 +97,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py index f2658159b4..842b93648b 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py @@ -70,7 +70,7 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 @@ -88,6 +88,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py index e0da134875..3dea3cf955 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py @@ -38,7 +38,7 @@ def generate_new_batch(self, current_batch: Batch): abort_req_list = [] aborted_count = 0 for req in self.waiting_req_list: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏 aborted_count += 1 @@ -53,6 +53,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index a73823b8b7..e5f731df5f 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -27,6 +27,12 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: self.reqs_waiting_for_dp_index: List[List[Req]] = [] return + def free_aborted_req(self, req: Req): + dp_index = req.sample_params.suggested_dp_index + assert dp_index >= 0 and dp_index < self.dp_size_in_node + self.inner_queues[dp_index].free_aborted_req(req) + return + def get_dp_queue(self, dp_index: int): assert dp_index < self.dp_size_in_node, "dp index out of range" return self.inner_queues[dp_index] diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e0b2bd425e..3563739f79 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -30,6 +30,7 @@ from ..models.qwen2_vl.model import QWen2VLTokenizer from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer +from ..models.neo_chat_moe.model import NeoChatTokenizer from ..models.gemma3.model import Gemma3Tokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. @@ -104,5 +105,7 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + elif model_type == "neo_chat": + tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) return tokenizer diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 202c2fc453..a54a4aeffa 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -8,7 +8,6 @@ import inspect import setproctitle from typing import List -from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -19,6 +18,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import BaseReq, GenerateReqIndex from rpyc.utils.classic import obtain @@ -49,7 +49,7 @@ def __init__( self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cache_port = args.cache_port - self.waiting_reqs: List[GroupReqIndexes] = [] + self.waiting_reqs: List[BaseReq] = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp self.vit_dp = args.visual_dp @@ -187,11 +187,12 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + # 目前只有 GenerateReqIndex 会进入这个队列,判断是否需要推理图片 + if isinstance(recv_req, GenerateReqIndex): self.waiting_reqs.append(recv_req) else: - assert False, f"Error Req Inf {recv_req}" + self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL) self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 8f4a1ee450..22dfa915ba 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -4,6 +4,7 @@ import torch import socket import inspect +import setproctitle from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig @@ -19,12 +20,15 @@ from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel +from lightllm.models.neo_chat_moe.neo_visual import NeoVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name class VisualModelRpcServer(rpyc.Service): @@ -80,6 +84,8 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() + elif self.model_type == "neo_chat": + self.model = NeoVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() else: raise Exception(f"can not support {self.model_type} now") @@ -172,6 +178,8 @@ async def encode(self, images: List[ImageItem]): def _init_env(port, device_id): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server::RANK{device_id}") + start_parent_check_thread() import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index a1ed6ed950..9e46a57ec3 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -100,7 +100,8 @@ def get_current_device_name(): gpu_name = gpu_name.replace(" ", "_") return gpu_name else: - return None + return "unknown" # need fix + # raise RuntimeError("No GPU available") @lru_cache(maxsize=None) diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 65ac401d4c..28667c6d00 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -65,12 +65,15 @@ def init_vision_distributed_env(kvargs): device_id = visual_gpu_ids[kvargs["vit_rank_id"]] set_current_device_id(device_id) torch.cuda.set_device(device_id) + # 不要在init_process_group时,显示的传入device_id + # 这会触发torch的device-bound split优化,会默认后面想加入新进程组的rank + # 都已经存在于默认组,这样RL更新weight的init_group时,外部想加入的组,在执行 + # 通信原语时例如all_reduce,会永远等不到LightLLM默认组里的回复,从而导致错误结果。 dist.init_process_group( "nccl", init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}', rank=kvargs["tp_rank_id"], world_size=tp_world_size, - device_id=torch.device(f"cuda:{device_id}"), ) # warmup nccl communicator _a = torch.zeros([1]).to(f"cuda:{device_id}") @@ -104,7 +107,6 @@ def init_distributed_env(kvargs): init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}', rank=kvargs["rank_id"], world_size=kvargs["world_size"], - device_id=torch.device(f"cuda:{device_id}"), ) # warmup nccl communicator _a = torch.zeros([1]).to(f"cuda:{device_id}") @@ -270,3 +272,71 @@ def _init_nccl_env(): assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}" return + + +# copy from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py#L1675 +def init_custom_process_group( + backend=None, + init_method=None, + timeout=None, + world_size=-1, + rank=-1, + store=None, + group_name=None, + pg_options=None, + device_id=None, +): + from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, + ) + + assert (store is None) or (init_method is None), "Cannot specify both init_method and store." + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore(group_name, store) + + # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 + # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 + # We need to determine the appropriate parameter name based on PyTorch version + pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + device_id=device_id, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 3a0e28bcb6..a702a465b2 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -13,6 +13,7 @@ def set_unique_server_name(args): if args.run_mode == "pd_master": os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master" else: + assert str(args.nccl_port) != "None", "nccl_port is not set" os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank) return @@ -20,6 +21,8 @@ def set_unique_server_name(args): @lru_cache(maxsize=None) def get_unique_server_name(): service_uni_name = os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID") + assert "None" not in service_uni_name, "service_uni_name is not set" + return service_uni_name diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 3256fdd1fd..f44aad92ac 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -24,6 +24,7 @@ PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, + NeoMemoryManager, ) from typing import List, Tuple, Optional @@ -97,6 +98,28 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": scale_head_dim=get_head_dim(args.model_dir) // 8, scale_data_type=get_llm_data_type(), ) + elif mem_manager_class is PPLINT8KVMemoryManager: + cpu_cache_meta = CpuKVCacheMeta( + page_num=0, + token_page_size=args.cpu_cache_token_page_size, + layer_num=get_layer_num(args.model_dir), + num_heads=get_num_key_value_heads(args.model_dir) * 2, + head_dim=get_head_dim(args.model_dir) * 2, + data_type=get_llm_data_type(), + scale_head_dim=0, + scale_data_type=get_llm_data_type(), + ) + elif mem_manager_class is MemoryManager: + cpu_cache_meta = CpuKVCacheMeta( + page_num=0, + token_page_size=args.cpu_cache_token_page_size, + layer_num=get_layer_num(args.model_dir), + num_heads=get_num_key_value_heads(args.model_dir) * 2, + head_dim=get_head_dim(args.model_dir), + data_type=get_llm_data_type(), + scale_head_dim=0, + scale_data_type=get_llm_data_type(), + ) else: logger.error(f"not support mem manager: {mem_manager_class} for cpu kv cache") raise Exception(f"not support mem manager: {mem_manager_class} for cpu kv cache") diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index 20b9888753..486414e88e 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -2,44 +2,72 @@ import subprocess import ipaddress import random +import portpicker from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000): + if used_nccl_ports is None: + used_nccl_ports = [] + port_list = [] - for port in range(from_port_num, 65536): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) - if result != 0 and port not in used_nccl_ports: + max_attempts = num * 50 # Allow more attempts to find ports in range + + for _ in range(max_attempts): + if len(port_list) >= num: + break + + try: + port = portpicker.pick_unused_port() + + if port >= from_port_num and port not in used_nccl_ports: port_list.append(port) - if len(port_list) > num * 30: - break + logger.debug(f"Allocated port: {port}") + else: + logger.debug(f"Port {port} is out of range or in used_nccl_ports, skipping") + + except Exception as e: + logger.warning(f"Failed to allocate port: {e}") + continue if len(port_list) < num: + logger.error(f"Failed to allocate {num} ports, only got {len(port_list)}") return None - random.shuffle(port_list) - return port_list[0:num] + logger.info(f"Successfully allocated {len(port_list)} ports: {port_list}") + return port_list def alloc_can_use_port(min_port, max_port): port_list = [] for port in range(min_port, max_port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: port_list.append(port) + except Exception: + continue return port_list def find_available_port(start_port, end_port): for port in range(start_port, end_port + 1): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - result = sock.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: return port + except Exception: + continue return None diff --git a/lightllm/utils/patch_torch.py b/lightllm/utils/patch_torch.py new file mode 100644 index 0000000000..9f51edeb64 --- /dev/null +++ b/lightllm/utils/patch_torch.py @@ -0,0 +1,63 @@ +# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/patch_torch.py +from typing import Callable, Union + +import torch +from packaging import version +from torch.multiprocessing import reductions + + +def monkey_patch_torch_reductions(): + """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed""" + + # Currently, NPU does not support UUID. This has been temporarily commented out, + # with support expected in the fourth quarter. + # if _is_npu: + # return + + if hasattr(reductions, "_reduce_tensor_original"): + return + + reductions._reduce_tensor_original = reductions.reduce_tensor + reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor + + reductions.reduce_tensor = _reduce_tensor_modified + reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified + + reductions.init_reductions() + + +# The signature has not been changed for years, and we will not need this when the next version is released, +# so it looks safe to use a constant. +_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 + + +def _reduce_tensor_modified(*args, **kwargs): + output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) + output_args = _modify_tuple(output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid) + return output_fn, output_args + + +def _rebuild_cuda_tensor_modified(*args): + args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) + return reductions._rebuild_cuda_tensor_original(*args) + + +def _device_to_uuid(device: int) -> str: + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: + if isinstance(device_maybe_uuid, int): + return device_maybe_uuid + + if isinstance(device_maybe_uuid, str): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: + return device + raise Exception("Invalid device_uuid=" + device_maybe_uuid) + + raise Exception(f"Unknown type: {device_maybe_uuid=}") + + +def _modify_tuple(t, index: int, modifier: Callable): + return *t[:index], modifier(t[index]), *t[index + 1 :] diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py new file mode 100644 index 0000000000..d8180aeb0c --- /dev/null +++ b/lightllm/utils/serializer.py @@ -0,0 +1,131 @@ +# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py + +import base64 +import pickle +import io +from dataclasses import dataclass +from multiprocessing.reduction import ForkingPickler +from typing import List + + +class MultiprocessingSerializer: + @staticmethod + def serialize(obj, output_str: bool = False): + """ + Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + + if output_str: + # Convert bytes to base64-encoded string + output = base64.b64encode(output).decode("utf-8") + + return output + + @staticmethod + def deserialize(data): + """ + Deserialize a previously serialized object. + + Args: + data (bytes or str): The serialized data, optionally base64-encoded. + + Returns: + The deserialized Python object. + """ + if isinstance(data, str): + # Decode base64 string to bytes + data = base64.b64decode(data, validate=True) + + return SafeUnpickler(io.BytesIO(data)).load() + + +class SafeUnpickler(pickle.Unpickler): + ALLOWED_MODULE_PREFIXES = { + # --- Python types --- + "builtins.", + "collections.", + "copyreg.", + "functools.", + "itertools.", + "operator.", + "types.", + "weakref.", + # --- PyTorch types --- + "torch.", + "torch._tensor.", + "torch.storage.", + "torch.nn.parameter.", + "torch.autograd.function.", + # --- torch distributed --- + "torch.distributed.", + "torch.distributed._shard.", + "torch.distributed._composable.", + "torch._C._distributed_c10d.", + "torch._C._distributed_fsdp.", + "torch.distributed.optim.", + # --- multiprocessing --- + "multiprocessing.resource_sharer.", + "multiprocessing.reduction.", + "pickletools.", + # --- PEFT / LoRA --- + "peft.", + "transformers.", + "huggingface_hub.", + # --- SGLang & Unitest --- + "sglang.srt.weight_sync.tensor_bucket.", + "sglang.srt.model_executor.model_runner.", + "sglang.srt.layers.", + "sglang.srt.utils.", + # --- LightLLM --- + "lightllm.utils.", + } + + DENY_CLASSES = { + ("builtins", "eval"), + ("builtins", "exec"), + ("builtins", "compile"), + ("os", "system"), + ("subprocess", "Popen"), + ("subprocess", "run"), + ("codecs", "decode"), + ("types", "CodeType"), + ("types", "FunctionType"), + } + + def find_class(self, module, name): + # Block deterministic attacks + if (module, name) in self.DENY_CLASSES: + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" + ) + # Allowlist of safe-to-load modules. + if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES): + return super().find_class(module, name) + + # Block everything else. (Potential attack surface) + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" + ) + + +@dataclass +class LocalSerializedTensor: + """torch.Tensor that gets serialized by MultiprocessingSerializer + (which only serializes a pointer and not the data). + The i-th element in the list corresponds to i-th rank's GPU.""" + + values: List[bytes] + + def get(self, rank: int): + return MultiprocessingSerializer.deserialize(self.values[rank]) diff --git a/lightllm/utils/tensor_bucket.py b/lightllm/utils/tensor_bucket.py new file mode 100644 index 0000000000..a9d7a367dd --- /dev/null +++ b/lightllm/utils/tensor_bucket.py @@ -0,0 +1,104 @@ +# copy from +# https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/ +# srt/weight_sync/tensor_bucket.py +from dataclasses import dataclass +from typing import List, Tuple + +import torch + + +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. + """ + + # This field is solely for users of to check whether the class supports this feature + supports_multi_dtypes = True + + def __init__( + self, + named_tensors: List[Tuple[str, torch.Tensor]] = None, + flattened_tensor: torch.Tensor = None, + metadata: List[FlattenedTensorMetadata] = None, + ): + """ + Initialize a tensor bucket from a list of named tensors OR from pre-flattened data. + Args: + named_tensors: List of (name, tensor) tuples (for creating new bucket) + flattened_tensor: Pre-flattened tensor (for reconstruction) + metadata: Pre-computed metadata (for reconstruction) + """ + if named_tensors is not None: + # Create bucket from named tensors + self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors) + self.flattened_tensor: torch.Tensor = None + + if not named_tensors: + raise ValueError("Cannot create empty tensor bucket") + + # Collect metadata and flatten tensors + current_idx = 0 + flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors) + + for i, (name, tensor) in enumerate(named_tensors): + flattened = tensor.flatten().view(torch.uint8) + flattened_tensors[i] = flattened + + # Store metadata + + numel = flattened.numel() + metadata_obj = FlattenedTensorMetadata( + name=name, + shape=tensor.shape, + dtype=tensor.dtype, + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + self.metadata[i] = metadata_obj + current_idx += numel + + # Concatenate all flattened tensors + self.flattened_tensor = torch.cat(flattened_tensors, dim=0) + else: + # Initialize from pre-flattened data + if flattened_tensor is None or metadata is None: + raise ValueError("Must provide either named_tensors or both flattened_tensor and metadata") + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Get the flattened tensor containing all bucket tensors""" + return self.flattened_tensor + + def get_metadata(self) -> List[FlattenedTensorMetadata]: + """Get metadata for all tensors in the bucket""" + return self.metadata + + def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + """ + Reconstruct original tensors from flattened tensor with optimized performance. + Uses memory-efficient operations to minimize allocations and copies. + """ + # preallocate the result list + reconstructed = [None] * len(self.metadata) + + for i, meta in enumerate(self.metadata): + tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].view(meta.dtype).reshape(meta.shape) + + reconstructed[i] = (meta.name, tensor) + + return reconstructed diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py new file mode 100644 index 0000000000..c1184ef30c --- /dev/null +++ b/lightllm/utils/torch_memory_saver_utils.py @@ -0,0 +1,92 @@ +import torch +from contextlib import contextmanager +from enum import Enum +from lightllm.utils.log_utils import init_logger + +try: + from torch_memory_saver import ( + torch_memory_saver, + configure_subprocess, + ) + + HAS_TORCH_MEMORY_SAVER = True + +except ImportError: + HAS_TORCH_MEMORY_SAVER = False + pass + +logger = init_logger(__name__) + + +class MemoryTag(Enum): + KV_CACHE = "kv_cache" + WEIGHT = "weights" + GRAPH = "graph" + + def is_kv_cache(self): + return self == MemoryTag.KV_CACHE + + def is_weight(self): + return self == MemoryTag.WEIGHT + + def is_graph(self): + return self == MemoryTag.GRAPH + + def __str__(self): + return self.value + + +class TorchMemorySaverWrapper: + def __new__(cls, enable_torch_memory_saver: bool = False): + if enable_torch_memory_saver: + assert ( + HAS_TORCH_MEMORY_SAVER + ), "torch_memory_saver is not installed, please install it via `pip install torch_memory_saver`." + return _TorchMemorySaver() + else: + return _TorchMemorySaverFake() + + +class _TorchMemorySaver: + def configure_subprocess(self): + return configure_subprocess() + + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + return torch_memory_saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup) + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch_memory_saver.cuda_graph(cuda_graph=graph_obj, **kwargs, tag=MemoryTag.GRAPH.value) + + def disable(self): + return torch_memory_saver.disable() + + def pause(self, tag: MemoryTag): + return torch_memory_saver.pause(tag=tag.value) + + def resume(self, tag: MemoryTag): + return torch_memory_saver.resume(tag=tag.value) + + +class _TorchMemorySaverFake: + @contextmanager + def configure_subprocess(self): + yield + + @contextmanager + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + yield + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch.cuda.graph(graph_obj, **kwargs) + + @contextmanager + def disable(self): + yield + + def pause(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, pause is not supported.") + return + + def resume(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, resume is not supported.") + return diff --git a/requirements.txt b/requirements.txt index a3b9473f82..bbf9668221 100644 --- a/requirements.txt +++ b/requirements.txt @@ -95,3 +95,5 @@ partial_json_parser==0.2.1.1.post6 websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 +torch_memory_saver==0.0.9 +portpicker==1.6.0 \ No newline at end of file diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py new file mode 100644 index 0000000000..85c4e44ef9 --- /dev/null +++ b/test/test_api/test_r3.py @@ -0,0 +1,92 @@ +import sys +import argparse +import requests +import base64 +import numpy as np + + +def test_routing_export(url: str = "http://localhost:8000"): + print(f"Testing routing export at {url}") + print("-" * 50) + + try: + response = requests.post( + f"{url}/generate", + json={ + "inputs": "What is the capital of France? What is the capital of France?", + "parameters": { + "max_new_tokens": 50, + # "return_routed_experts": True, + # "repetition_penalty": 1.0, + }, + }, + timeout=60, + ) + except requests.exceptions.ConnectionError: + print(f"ERROR: Cannot connect to server at {url}") + print("Make sure the LightLLM server is running with --enable_return_routed_experts") + return False + except requests.exceptions.Timeout: + print("ERROR: Request timed out") + return False + + print(f"Status: {response.status_code}") + + if response.status_code != 200: + print(f"ERROR: Request failed with status {response.status_code}") + print(f"Response: {response.text}") + return False + + res = response.json() + print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...") + + if "routed_experts" not in res or not res["routed_experts"]: + print("\nWARNING: No routed_experts in response.") + print("This could mean:") + print(" - The model is not a MoE model") + print(" - The server was not started with --enable_return_routed_experts") + print(" - The routing capture manager was not initialized") + return False + + routing_info = res["routed_experts"] + shape = routing_info["shape"] + dtype_str = routing_info["dtype"] + dtype = np.dtype(dtype_str) + data = base64.b64decode(routing_info["data"]) + routing_array = np.frombuffer(data, dtype=dtype).reshape(shape) + + print(f"\n{'=' * 50}") + print("ROUTING CAPTURE SUCCESS!") + print(f"{'=' * 50}") + print(f"Shape: {shape}") + print(f"Dtype: {dtype}") + print(f"Num tokens: {shape[0]}") + print(f"Num MoE layers: {shape[1]}") + print(f"Top-K: {shape[2]}") + + # Compute payload size savings + int32_size = np.prod(shape) * 4 + actual_size = len(data) + savings = (1 - actual_size / int32_size) * 100 + print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)") + + print(f"\nSample routing (first layer, first 5 tokens):") + num_tokens_to_show = shape[0] + for i in range(num_tokens_to_show): + print(f" Token {i}: experts {routing_array[i, 0, :].tolist()}") + + if np.all(routing_array == 0): + print("\nWARNING: All routing data is zeros. Capture may not be working correctly.") + return False + + print("\nTest PASSED!") + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test R3 routing export feature") + parser.add_argument("--url", default="http://localhost:8000", help="Server URL") + args = parser.parse_args() + + success = test_routing_export(args.url) + sys.exit(0 if success else 1) diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py new file mode 100644 index 0000000000..dcc010b372 --- /dev/null +++ b/unit_tests/common/basemodel/test_routing_capture_manager.py @@ -0,0 +1,219 @@ +import torch +import numpy as np + + +class TestRoutingCaptureManager: + def test_capture_and_extract_basic(self): + """Test the core pipeline: capture → flush_to_kv_buffer → extract_from_gpu.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + kv_cache_size=1024, + max_capture_tokens=64, + ) + + # Simulate a batch of 10 tokens at KV-cache positions [100..109] + mem_indexes = torch.arange(100, 110, device="cuda") + + # Capture routing for each MoE layer (writes to capture buffer) + for layer_idx in range(4): + topk_ids = torch.randint(0, 64, (10, 8), device="cuda") + manager.capture(moe_layer_index=layer_idx, topk_ids=topk_ids, microbatch_index=0) + + # Flush from capture buffer to KV-indexed gpu_kv_buffer + manager.flush_to_kv_buffer(mem_indexes, num_tokens=10, microbatch_index=0) + + # Extract for those same KV-cache positions + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (4, 10, 8) + assert result.dtype == np.int8 + + def test_capture_writes_to_correct_kv_positions(self): + """Verify that captured data lands in the right KV-cache positions after flush.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + kv_cache_size=256, + max_capture_tokens=16, + ) + + # Use non-contiguous mem_indexes to simulate real KV-cache + mem_indexes = torch.tensor([10, 50, 200], device="cuda") + + topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + topk_ids_layer1 = topk_ids + 20 + manager.capture(moe_layer_index=1, topk_ids=topk_ids_layer1, microbatch_index=0) + + # Flush to KV positions + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + + # Extract and verify + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (2, 3, 4) + np.testing.assert_array_equal(result[0], topk_ids.cpu().numpy()) + np.testing.assert_array_equal(result[1], topk_ids_layer1.cpu().numpy()) + + def test_microbatch_isolation(self): + """Two microbatches writing to different KV positions don't interfere.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=32, + kv_cache_size=256, + max_capture_tokens=16, + ) + + # Microbatch 0: positions [10, 11] + mem0 = torch.tensor([10, 11], device="cuda") + ids_0 = torch.ones((2, 4), dtype=torch.int64, device="cuda") + manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0) + + # Microbatch 1: positions [20, 21] + mem1 = torch.tensor([20, 21], device="cuda") + ids_1 = torch.ones((2, 4), dtype=torch.int64, device="cuda") * 2 + manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1) + + # Flush each microbatch to different KV positions + manager.flush_to_kv_buffer(mem0, num_tokens=2, microbatch_index=0) + manager.flush_to_kv_buffer(mem1, num_tokens=2, microbatch_index=1) + + # Extract microbatch 0 + result0 = manager.extract_from_gpu(mem0) + assert result0.shape == (1, 2, 4) + assert result0[0, 0, 0] == 1 + + # Extract microbatch 1 + result1 = manager.extract_from_gpu(mem1) + assert result1[0, 0, 0] == 2 + + def test_dtype_selection_int8(self): + """Models with ≤127 experts use int8.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=64, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int8 + assert manager.np_dtype == np.int8 + assert manager.dtype_id == 1 + + def test_dtype_selection_int16(self): + """Models with >127 experts use int16.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=256, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int16 + assert manager.np_dtype == np.int16 + assert manager.dtype_id == 2 + + def test_extract_preserves_values(self): + """Extracted values exactly match what was captured, no dtype truncation.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=64, + kv_cache_size=64, + max_capture_tokens=16, + ) + + mem_indexes = torch.tensor([0, 1, 2], device="cuda") + + topk_ids = torch.tensor([[10, 20, 30, 40], [50, 60, 63, 1], [0, 5, 127, 3]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + # Flush then extract + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + result = manager.extract_from_gpu(mem_indexes) + expected = topk_ids.cpu().numpy().astype(np.int8) + np.testing.assert_array_equal(result[0], expected) + + def test_gpu_kv_buffer_shape(self): + """Buffer shape is (num_moe_layers, kv_cache_size, topk).""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + # 127 experts fits in int8 (max value 127) + manager = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=127, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager.gpu_kv_buffer.shape == (48, 2048, 8) + assert manager.gpu_kv_buffer.dtype == torch.int8 + assert manager.gpu_kv_buffer.device.type == "cuda" + + # 128 experts requires int16 + manager2 = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=128, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager2.gpu_kv_buffer.dtype == torch.int16 + + def test_partial_token_capture(self): + """capture() only writes num_tokens rows to the buffer.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=32, + kv_cache_size=128, + max_capture_tokens=16, + ) + + # Capture only 3 tokens, flush to 5 KV positions (first 3 get data) + mem_indexes = torch.tensor([10, 11, 12, 13, 14], device="cuda") + + topk_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], device="cuda") # only 3 tokens + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + # Flush only the 3 captured tokens + manager.flush_to_kv_buffer(mem_indexes[:3], num_tokens=3, microbatch_index=0) + + # Positions 10-12 should have data, 13-14 should be zeros (from init) + result_written = manager.extract_from_gpu(mem_indexes[:3]) + np.testing.assert_array_equal(result_written[0], topk_ids.cpu().numpy().astype(np.int8)) + + result_unwritten = manager.extract_from_gpu(mem_indexes[3:]) + np.testing.assert_array_equal(result_unwritten[0], np.zeros((2, 2), dtype=np.int8)) + + def test_capture_buffer_shape(self): + """Capture buffer has correct shape (max_tokens, num_moe_layers, topk).""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + kv_cache_size=1024, + max_capture_tokens=256, + ) + assert manager._capture_buffer[0].shape == (256, 4, 8) + assert manager._capture_buffer[1].shape == (256, 4, 8) + assert manager._capture_buffer[0].dtype == torch.int8 diff --git a/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py new file mode 100644 index 0000000000..3b2f159f62 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py @@ -0,0 +1,50 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import ( + apply_invalid_token_ids, +) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_apply_invalid_token_ids(dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for Triton kernels.") + + batch_size = 4 + vocab_size = 32 + logits = torch.randn((batch_size, vocab_size), device="cuda", dtype=dtype) + expected = logits.clone() + + invalid_token_ids_per_batch = [ + [1, 3, 5], + [], + [0, 2, 31], + [7], + ] + + flat_ids = [] + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for ids in invalid_token_ids_per_batch: + flat_ids.extend(ids) + invalid_token_num_start += len(ids) + cu_invalid_token_num.append(invalid_token_num_start) + + invalid_token_ids = torch.tensor(flat_ids, device="cuda", dtype=torch.int32) + cu_invalid_token_num = torch.tensor(cu_invalid_token_num, device="cuda", dtype=torch.int32) + + for batch_idx, ids in enumerate(invalid_token_ids_per_batch): + if ids: + expected[batch_idx, ids] = float("-inf") + + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + assert torch.equal(logits, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py index 605433e9d8..dfeda0b6f7 100644 --- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py +++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py @@ -230,5 +230,32 @@ def test_case9(): assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64)) +def test_case10(): + """ + 测试场景:测试 flush_cache 函数 + """ + print("\nTest Case 10: Testing flush_cache function\n") + tree = RadixCache("unique_name", 100, 0) + tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64)) + tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)) + tree_node, size, values = tree.match_prefix( + torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True + ) + assert tree_node is not None + assert size == 3 + tree.flush_cache() + tree_node, size, values = tree.match_prefix( + torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True + ) + assert tree_node is None + assert size == 0 + assert tree.get_tree_total_tokens_num() == 0 + assert tree.get_refed_tokens_num() == 0 + assert len(tree.root_node.children) == 0 + assert tree.root_node.token_id_key.numel() == 0 + assert tree.root_node.token_mem_index_value.numel() == 0 + assert tree.root_node.ref_counter == 1 + + if __name__ == "__main__": pytest.main()