diff --git a/.gitignore b/.gitignore
index 63408699f4..3fb49db8b1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ dist
.vscode
tmp/
requirements-musa.txt
+CLAUDE.md
diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py
index 859d97ca84..6429bce9a0 100644
--- a/lightllm/common/basemodel/attention/base_att.py
+++ b/lightllm/common/basemodel/attention/base_att.py
@@ -65,6 +65,7 @@ class AttControl:
mla_prefill_dict: Dict = None
mla_decode: bool = False
mla_decode_dict: Dict = None
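+    # Optional softmax scale override; attention backends fall back to 1/sqrt(head_dim) when this is None.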
+ scale: float = None
@dataclass
diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py
index 952bb39d91..2f5fccd57b 100644
--- a/lightllm/common/basemodel/attention/fa3/fp.py
+++ b/lightllm/common/basemodel/attention/fa3/fp.py
@@ -220,8 +220,11 @@ def _normal_decode_att(
sink_weight = None
k_descale, v_descale = None, None # disable quantization
- Lq = q.shape[-1]
- sm_scale = 1.0 / (Lq ** 0.5)
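+    # Prefer an explicitly provided softmax scale; otherwise derive it from the query head dim.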
+ if att_control.scale is not None:
+ sm_scale = att_control.scale
+ else:
+ Lq = q.shape[-1]
+ sm_scale = 1.0 / (Lq ** 0.5)
o = flash_attn_with_kvcache(
q=q,
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 5c1d2b8712..6702653c8e 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -6,7 +6,7 @@
import json
import torch
import torch.nn.functional as F
-from typing import final, List
+from typing import final, List, Optional
+import gc  # used by the release_*/resume_* memory helpers below
from tqdm import tqdm
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
@@ -32,6 +32,10 @@
from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.utils.infer_utils import post_empty_cache
+from lightllm.utils.torch_memory_saver_utils import (
+ TorchMemorySaverWrapper,
+ MemoryTag,
+)
from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
from .attention import BaseAttBackend
@@ -90,6 +94,7 @@ def __init__(self, kvargs):
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
+ self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)
self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
@@ -103,15 +108,17 @@ def __init__(self, kvargs):
self._verify_params()
self._init_quant()
- self._init_weights()
- self._init_mem_manager()
- self._init_kv_move_buffer()
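+        # Allocate weights and kv cache inside tagged memory-saver regions so they can be released and resumed later.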
+ with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup):
+ self._init_weights()
+ with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE):
+ self._init_mem_manager()
+ self._init_kv_move_buffer()
+ self._init_req_manager()
self._check_mem_size()
- self._init_req_manager()
self._init_infer_layer()
self._init_some_value()
self._init_custom()
- self._load_hf_weights()
+ self.load_weights(self.weight_dict)
# wait必须在init cudagraph 之前,避免错误捕获
self._wait_other_modules_ready()
@@ -176,17 +183,15 @@ def _init_weights(self, start_layer_index=0):
]
return
- def _load_hf_weights(self):
+ def load_weights(self, weight_dict: dict):
+ assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None"
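+        # When weight_dict is provided, weights are taken from it instead of being read from weight_dir_.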
load_hf_weights(
self.data_type,
- weight_dir=self.weight_dir_,
+ self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
- weight_dict=self.weight_dict,
+ weight_dict=weight_dict,
)
- self.pre_post_weight.verify_load()
- [weight.verify_load() for weight in self.trans_layers_weight]
- return
def _init_mem_manager(self):
assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
@@ -884,6 +889,7 @@ def _check_max_len_infer(self):
)
logger.error(exception_str)
raise Exception(exception_str)
+ torch.cuda.empty_cache()
return
def autotune_layers(self):
@@ -1012,6 +1018,9 @@ def _init_padded_req(self):
del b_seq_len
del b_ready_cache_len
del model_output
+ del b_mtp_index
+ del b_prefill_start_loc
+ del b_q_seq_len
torch.cuda.empty_cache()
return
@@ -1032,3 +1041,71 @@ def _gen_special_model_input(self, token_num: int):
special_model_input["mtp_draft_input_hiddens"] = None
return special_model_input
+
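+    # Pause/resume tagged GPU allocations (weights, kv cache, cuda graphs) through torch_memory_saver.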
+ def release_memory_occupation(self, tags: Optional[List[MemoryTag]]):
+ if tags is None:
+ self.release_all()
+ return
+ if MemoryTag.WEIGHT in tags:
+ self.release_weight()
+ if MemoryTag.KV_CACHE in tags:
+ self.release_kv_cache()
+ if MemoryTag.GRAPH in tags:
+ self.release_graph()
+ return
+
+ def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]):
+ if tags is None:
+ self.resume_all()
+ return
+ if MemoryTag.WEIGHT in tags:
+ self.resume_weight()
+ if MemoryTag.KV_CACHE in tags:
+ self.resume_kv_cache()
+ if MemoryTag.GRAPH in tags:
+ self.resume_graph()
+ return
+
+ def release_weight(self):
+ self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def release_kv_cache(self):
+ self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def release_graph(self):
+ self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def release_all(self):
+ self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
+ self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
+ self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def resume_weight(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
+
+ def resume_kv_cache(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
+
+ def resume_graph(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
+
+ def resume_all(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
+ self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
+ self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index dd29c9a833..2007e4db4e 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -7,6 +7,10 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from lightllm.utils.torch_memory_saver_utils import (
+ TorchMemorySaverWrapper,
+ MemoryTag,
+)
from .infer_struct import InferStateInfo
@@ -24,6 +28,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
self.max_batch_size = max_batch_size
self.graph_max_len_in_batch = max_len_in_batch
self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+ self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)
# gen cuda graph batch_sizes
# cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -89,7 +94,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo):
delattr(infer_state, param_name)
with lightllm_capture_graph(dist_group):
- with torch.cuda.graph(graph_obj, pool=self.mempool):
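+            # Capture through the memory-saver wrapper so the graph's memory can be paused and resumed later.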
+ with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output = decode_func(infer_state)
self.graph[batch_size] = (graph_obj, infer_state, model_output)
graph_obj.replay()
@@ -127,7 +132,7 @@ def _capture_decode_overlap(
with lightllm_capture_graph(dist_group1):
with lightllm_capture_graph(dist_group):
- with torch.cuda.graph(graph_obj, pool=self.mempool):
+ with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(infer_state, infer_state1)
self.graph[batch_size] = (
graph_obj,
diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py
index 8cf66a5ad6..ec0e282844 100755
--- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py
+++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py
@@ -5,6 +5,8 @@
from tqdm import tqdm
import lightllm.utils.petrel_helper as utils
from lightllm.utils.dist_utils import get_current_device_id
+from queue import Queue
+from threading import Thread
def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
@@ -65,7 +67,6 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye
iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1)
desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers"
iterator = tqdm(iterator, total=len(candidate_files), desc=desc_str)
-
for _ in iterator:
pass
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
index 6bcf7fc03c..3dc888b6ac 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
@@ -33,6 +33,7 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
+ moe_layer_index: int = 0,
) -> None:
super().__init__(data_type=data_type)
self.w1_weight_name = gate_proj_name
@@ -50,6 +51,7 @@ def __init__(
self.enable_ep_moe = get_env_start_args().enable_ep_moe
self.n_routed_experts = n_routed_experts
self.num_fused_shared_experts = num_fused_shared_experts
+ self.moe_layer_index = moe_layer_index
self._init_config(network_config)
self._init_redundancy_expert_params()
self._init_parallel_params()
@@ -130,6 +132,7 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ microbatch_index: int = 0,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
return self.fuse_moe_impl(
@@ -145,6 +148,8 @@ def experts(
topk_group=topk_group,
num_expert_group=num_expert_group,
is_prefill=is_prefill,
+ moe_layer_index=self.moe_layer_index,
+ microbatch_index=microbatch_index,
)
def low_latency_dispatch(
@@ -295,6 +300,7 @@ def _create_weight(self):
device_id=self.device_id_,
num_experts=self.local_n_routed_experts,
)
+ self.w1, self.w3 = w13_param_list
self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0])
self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1])
self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2)
@@ -312,6 +318,8 @@ def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[st
for expert_idx, local_expert_idx in expert_idx_to_local_idx.items():
with self.lock:
self._load_expert(expert_idx, local_expert_idx, weights)
+                # Also pick up merged (non-per-expert) weights, e.g. from RL weight updates.
+ self._load_merge_weight(weights)
self._load_expert_scale(
expert_idx,
local_expert_idx,
@@ -332,6 +340,7 @@ def _load_expert(
w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}"
w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}"
w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}"
+
row_slice_func = self.row_slicer._slice_weight
col_slice_func = self.col_slicer._slice_weight
if w1_weight in weights:
@@ -341,6 +350,19 @@ def _load_expert(
if w2_weight in weights:
self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx])
+ def _load_merge_weight(self, weights: Dict[str, torch.Tensor]):
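+        # Load merged (all-experts) w1/w2/w3 tensors keyed without an expert index, as pushed by in-memory weight updates.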
+ w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}"
+ w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}"
+ w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}"
+ row_slice_func = self.row_slicer._slice_weight
+ col_slice_func = self.col_slicer._slice_weight
+ if w1_merge_weight in weights:
+ self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1)
+ if w2_merge_weight in weights:
+ self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2)
+ if w3_merge_weight in weights:
+ self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3)
+
def _load_expert_scale(
self,
expert_idx: int,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
index 6ed0cef0b4..4ca1605be4 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
@@ -8,6 +8,7 @@
from lightllm.common.quantization import Quantcfg
from lightllm.common.quantization.quantize_method import QuantizationMethod
from lightllm.utils.log_utils import init_logger
+from lightllm.common.basemodel import routing_manager as _routing_mgr
logger = init_logger(__name__)
@@ -46,6 +47,7 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
+ moe_layer_index: int = 0,
) -> None:
network_config["norm_topk_prob"] = None
super().__init__(
@@ -62,6 +64,7 @@ def __init__(
num_fused_shared_experts=num_fused_shared_experts,
layer_num=layer_num,
network_config=network_config,
+ moe_layer_index=moe_layer_index,
)
self.hidden_size = network_config["hidden_size"]
@@ -144,10 +147,15 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ microbatch_index: int = 0,
):
topk_weights, topk_ids = self._router(router_logits, top_k)
+        # Capture routed expert ids for rollout router replay.
+ if _routing_mgr.g_routing_capture_manager is not None:
+ _routing_mgr.g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)
+
w1, w1_scale = self.w1
w2, w2_scale = self.w2
use_fp8_w8a8 = self.quant_method is not None
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
index 00587ac185..1c93cb13dc 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
@@ -62,5 +62,7 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ moe_layer_index: Optional[int] = None,
+ microbatch_index: int = 0,
) -> torch.Tensor:
pass
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
index 8bcdb4bf90..1e81b226ec 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
@@ -3,6 +3,7 @@
from lightllm.common.quantization.no_quant import WeightPack
from lightllm.common.quantization.quantize_method import QuantizationMethod
from .base_impl import FuseMoeBaseImpl
+from lightllm.common.basemodel import routing_manager as _routing_mgr
class FuseMoeTriton(FuseMoeBaseImpl):
@@ -124,6 +125,8 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ moe_layer_index: Optional[int] = None,
+ microbatch_index: int = 0,
):
topk_weights, topk_ids = self._select_experts(
input_tensor=input_tensor,
@@ -136,6 +139,10 @@ def __call__(
num_expert_group=num_expert_group,
scoring_func=scoring_func,
)
+
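+        # Capture routed expert ids for rollout router replay when a routing capture manager is active.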
+ if _routing_mgr.g_routing_capture_manager is not None and moe_layer_index is not None:
+ _routing_mgr.g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index)
+
output = self._fused_experts(
input_tensor=input_tensor,
w13=w13,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
index ddbf98a866..15f050c14a 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
@@ -47,17 +47,17 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
# 默认weight 的shape是 outxin,这也是目前最通用的约定。
-# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。
+# Convention here: row-wise slicing is along the second-to-last dim, col-wise slicing is along the last dim.
class RowSliceMixin(SliceMixinTpl):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1):
super().__init__(tp_rank, tp_world_size, repeat_times)
def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert (
- weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight.shape[0])
- return weight[start:end, :]
+ weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight.shape[-2])
+ return weight[..., start:end, :]
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
assert (
@@ -75,17 +75,17 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
assert (
- weight_scale.shape[0] % self.tp_world_size_ == 0
- ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_scale.shape[0])
- return weight_scale[start:end]
+ weight_scale.shape[-2] % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_scale.shape[-2])
+ return weight_scale[..., start:end, :]
def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
assert (
- weight_zero_point.shape[0] % self.tp_world_size_ == 0
- ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_zero_point.shape[0])
- return weight_zero_point[start:end]
+ weight_zero_point.shape[-2] % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_zero_point.shape[-2])
+ return weight_zero_point[..., start:end, :]
class ColSliceMixin(SliceMixinTpl):
@@ -94,10 +94,10 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert (
- weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight.shape[1])
- return weight[:, start:end]
+ weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight.shape[-1])
+ return weight[..., start:end]
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias / self.tp_world_size_ * self.repeat_times_
@@ -110,16 +110,16 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
assert (
-            weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
+            weight_scale.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_scale.shape[1])
- return weight_scale[:, start:end]
+ ), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_scale.shape[-1])
+ return weight_scale[..., start:end]
def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
assert (
- weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_zero_point.shape[1])
- return weight_zero_point[:, start:end]
+ weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_zero_point.shape[-1])
+ return weight_zero_point[..., start:end]
# awq 的量化权重是inxout存储格式,需要定制实现。
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
index 3630bc2c00..da9b3f4321 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
@@ -193,7 +193,9 @@ def _create_weight(self):
def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
for weight_name in self.weight_names:
if weight_name in weights:
- weight = self.param_slicer._slice_weight(weights[weight_name])
+ tp_start = self.tp_rank_ * self.dim0
+ tp_end = (self.tp_rank_ + 1) * self.dim0
+ weight = weights[weight_name][tp_start:tp_end, :, :]
self.weight.copy_(weight)
self.weight.load_ok = True
return
diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py
new file mode 100644
index 0000000000..77b611130f
--- /dev/null
+++ b/lightllm/common/basemodel/routing_manager.py
@@ -0,0 +1,191 @@
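+"""Capture per-token MoE routing choices (top-k expert ids) so finished requests can return routed-expert data."""
+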
+import atexit
+import torch
+import numpy as np
+from multiprocessing import shared_memory
+from typing import Optional
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_current_rank_in_dp
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
+from lightllm.utils.envs_utils import get_unique_server_name
+
+logger = init_logger(__name__)
+
+
+def routing_dtype_id_to_np(dtype_id: int):
+ if dtype_id == 1:
+ return np.uint8
+ elif dtype_id == 2:
+ return np.int16
+ return np.int32
+
+
+def get_routing_config_shm() -> SharedArray:
+ service_name = get_unique_server_name()
+ return SharedArray(f"{service_name}_routing_config", shape=(4,), dtype=np.int32)
+
+
+class RoutingCaptureManager:
+ def __init__(
+ self,
+ num_moe_layers: int,
+ topk: int,
+ num_experts: int,
+ kv_cache_size: int,
+ max_capture_tokens: int,
+ ):
+ self.num_moe_layers = num_moe_layers
+ self.topk = topk
+ self.num_experts = num_experts
+ self.kv_cache_size = kv_cache_size
+
+ self.dtype = torch.uint8 if num_experts <= 255 else torch.int16
+ dtype_bytes = 1 if self.dtype == torch.uint8 else 2
+
+ # Shape: (kv_cache_size, num_moe_layers, topk) — on CPU to save GPU memory.
+ # Written after forward() via flush_to_routing_buffer(), read on request finish.
+ routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes
+ self.routing_buffer = torch.zeros(
+ (kv_cache_size, num_moe_layers, topk),
+ dtype=self.dtype,
+ device="cpu",
+ )
+
+ # Capture buffers: simple contiguous tensors written to during forward().
+ capture_buf_size = max_capture_tokens * num_moe_layers * topk * dtype_bytes
+ self._capture_buffer = [
+ torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2)
+ ]
+
+ dtype_name = "uint8" if self.dtype == torch.uint8 else "int16"
+ logger.info(
+ f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
+ f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, "
+ f"capture_buffer={capture_buf_size / 1024 / 1024:.2f}MB x2, dtype={dtype_name}"
+ )
+
+ @property
+ def np_dtype(self):
+ return np.uint8 if self.dtype == torch.uint8 else np.int16
+
+ @property
+ def dtype_id(self) -> int:
+ return 1 if self.dtype == torch.uint8 else 2
+
+ def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
+ num_tokens = topk_ids.shape[0]
+ self._capture_buffer[microbatch_index][:num_tokens, moe_layer_index, :] = topk_ids.to(self.dtype)
+
+ def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None:
+ buf = self._capture_buffer[microbatch_index][:num_tokens] # (num_tokens, num_moe_layers, topk)
+ self.routing_buffer[mem_indexes[:num_tokens].cpu(), :, :] = buf.cpu()
+
+ def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray:
+ cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes
+ return self.routing_buffer[cpu_indexes, :, :].numpy()
+
+
+g_routing_capture_manager: Optional[RoutingCaptureManager] = None
+
+
+def create_routing_capture_manager(
+ num_moe_layers: int,
+ topk: int,
+ num_experts: int,
+ kv_cache_size: int,
+ max_capture_tokens: int,
+) -> None:
+ global g_routing_capture_manager
+ assert g_routing_capture_manager is None, "RoutingCaptureManager already exists"
+ g_routing_capture_manager = RoutingCaptureManager(
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ num_experts=num_experts,
+ kv_cache_size=kv_cache_size,
+ max_capture_tokens=max_capture_tokens,
+ )
+
+
+def cleanup_routing_shm_pool() -> None:
+ """Unlink all pre-allocated routing SHM segments. Called at server shutdown."""
+ try:
+ from lightllm.utils.envs_utils import get_env_start_args
+
+ args = get_env_start_args()
+ except Exception:
+ return
+
+ service_name = get_unique_server_name()
+
+ for i in range(args.running_max_req_size):
+ name = f"{service_name}_shm_routing_{i}"
+ try:
+ shm = shared_memory.SharedMemory(name=name)
+ shm.close()
+ shm.unlink()
+ except Exception:
+ pass
+
+ config_name = f"{service_name}_routing_config"
+ try:
+ shm = shared_memory.SharedMemory(name=config_name)
+ shm.close()
+ shm.unlink()
+ except Exception:
+ pass
+
+
+def init_routing_capture(model, num_moe_layers: int) -> None:
+ dp_rank = get_current_rank_in_dp()
+ logger.info(f"init_routing_capture called: num_moe_layers={num_moe_layers}, dp_rank={dp_rank}")
+ if dp_rank != 0:
+ logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}")
+ return
+
+ if num_moe_layers == 0:
+ logger.warning(
+ "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled."
+ )
+ return
+
+ num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
+ topk = model.config.get("num_experts_per_tok", 0)
+ assert num_experts > 0 and topk > 0
+
+ from lightllm.utils.envs_utils import get_env_start_args
+
+ args = get_env_start_args()
+
+ # Capture buffer must fit the max tokens in any single forward call.
+ # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size.
+ batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192
+ max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size)
+
+ logger.info(
+ f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
+ f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}"
+ )
+
+ create_routing_capture_manager(
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ num_experts=num_experts,
+ kv_cache_size=model.mem_manager.size + 1,
+ max_capture_tokens=max_capture_tokens,
+ )
+
+ mgr = g_routing_capture_manager
+ dtype_id = mgr.dtype_id
+
+ max_req_total_len = args.max_req_total_len
+
+ # Write config to cross-process SHM
+ shm = get_routing_config_shm()
+ shm.arr[0] = num_moe_layers
+ shm.arr[1] = topk
+ shm.arr[2] = dtype_id
+ shm.arr[3] = max_req_total_len
+ logger.info(
+ f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, "
+ f"dtype_id={dtype_id}, max_tokens={max_req_total_len}"
+ )
+ atexit.register(cleanup_routing_shm_pool)
diff --git a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
index 40322e5093..8e0de6a6e3 100644
--- a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
+++ b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
@@ -64,3 +64,28 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
num_warps=4,
)
return x
+
+
+@torch.no_grad()
+def torch_qk_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float):
+    # Pure-torch reference of qk_rmsnorm_forward, kept for debugging and kernel comparison.
+    assert torch.is_tensor(x) and torch.is_tensor(weight)
+
+    head_dim = weight.numel()
+    x2d = x.view(-1, x.shape[-1])  # (M2, N)
+    M2, N = x2d.shape
+    assert N % head_dim == 0, (N, head_dim)
+    H = N // head_dim
+
+    x3 = x2d.view(M2, H, head_dim)  # (M2, H, D)
+
+    x_fp32 = x3.to(torch.float32)
+    w = weight.view(1, 1, head_dim)
+
+    var = x_fp32.pow(2).mean(dim=-1, keepdim=True)
+    rstd = torch.rsqrt(var + eps)
+    y = (x_fp32 * rstd).to(torch.bfloat16) * w
+
+    x3.copy_(y.to(dtype=x3.dtype))
+    return x
diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py
index ca8f9a1c81..6988cc4113 100644
--- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py
+++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py
@@ -4,7 +4,7 @@
import triton.language as tl
import os
-rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8"))
+rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "4"))
@triton.jit
@@ -36,12 +36,12 @@ def _rms_norm_fwd_fused(
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
- w = tl.load(W + cols, mask=mask).to(tl.float32)
+ w = tl.load(W + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
- x_hat = x * rstd
- y = x_hat * w
+ x_hat = (x * rstd).to(tl.bfloat16)
+ y = x_hat * w.to(tl.bfloat16)
# Write output
- tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)
+ tl.store(Y + cols * y_stride1, y, mask=mask)
def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None):
@@ -79,22 +79,19 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None)
return y
-def torch_rms_norm(x, weight, eps):
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
-def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"):
- # create data
- x_shape = (M, N)
- w_shape = (x_shape[-1],)
- weight = torch.rand(w_shape, dtype=dtype, device="cuda")
- x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
- # forward pass
- y_tri = rmsnorm_forward(x, weight, eps)
- y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype)
-
- # compare
- print("type:", y_tri.dtype, y_ref.dtype)
- print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
- assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
- return
+def torch_rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
+    # Pure-torch reference of rmsnorm_forward, kept for debugging and kernel comparison.
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + eps)
+    return weight * hidden_states.to(input_dtype)
diff --git a/lightllm/common/basemodel/triton_kernel/post_process/__init__.py b/lightllm/common/basemodel/triton_kernel/post_process/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py
new file mode 100644
index 0000000000..353affd8ed
--- /dev/null
+++ b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py
@@ -0,0 +1,36 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_apply_invalid_token(
+ Logits,
+ invalid_token_ids,
+ cu_invalid_token_num,
+ stride_logit_b,
+):
+ cur_batch = tl.program_id(0)
+ start_index = tl.load(cu_invalid_token_num + cur_batch)
+ end_index = tl.load(cu_invalid_token_num + cur_batch + 1)
+ for i in range(start_index, end_index):
+ cur_invalid_token_id = tl.load(invalid_token_ids + i)
+ cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id
+ tl.store(cur_logit_ptr, float("-inf"))
+ return
+
+
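+# Masks per-request invalid token ids by writing -inf into their logits so they cannot be sampled.
+# cu_invalid_token_num holds cumulative offsets into invalid_token_ids (length batch_size + 1).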
+def apply_invalid_token_ids(
+ Logits: torch.Tensor,
+ invalid_token_ids: torch.Tensor,
+ cu_invalid_token_num: torch.Tensor,
+):
+ batch_size = Logits.shape[0]
+ grid = (batch_size,)
+ _fwd_kernel_apply_invalid_token[grid](
+ Logits=Logits,
+ invalid_token_ids=invalid_token_ids,
+ cu_invalid_token_num=cu_invalid_token_num,
+ stride_logit_b=Logits.stride(0),
+ )
+ return
diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py
similarity index 100%
rename from lightllm/common/basemodel/triton_kernel/apply_penalty.py
rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py
diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py
similarity index 100%
rename from lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py
rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py
diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py
index 7d516e6728..bcc1292097 100644
--- a/lightllm/common/kv_cache_mem_manager/__init__.py
+++ b/lightllm/common/kv_cache_mem_manager/__init__.py
@@ -4,6 +4,7 @@
from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
from .deepseek2_mem_manager import Deepseek2MemoryManager
+from .neo_mem_manager import NeoMemoryManager
__all__ = [
"MemoryManager",
@@ -13,4 +14,5 @@
"PPLINT4KVMemoryManager",
"PPLINT8KVMemoryManager",
"Deepseek2MemoryManager",
+ "NeoMemoryManager",
]
diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py
index 1203cbdec7..ce628aa8f7 100755
--- a/lightllm/common/kv_cache_mem_manager/mem_manager.py
+++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py
@@ -61,7 +61,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
self.size,
dtype,
head_num,
- head_dim,
+ self.head_dim,
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size
diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py
index 1ff58b89a0..686993dde4 100644
--- a/lightllm/common/kv_cache_mem_manager/mem_utils.py
+++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py
@@ -5,6 +5,7 @@
PPLINT8KVMemoryManager,
PPLINT4KVMemoryManager,
Deepseek2MemoryManager,
+ NeoMemoryManager,
)
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
@@ -19,12 +20,16 @@ def select_mem_manager_class():
# case 1
# 先判断是否是 deepseek 系列的模型
model_class = get_llm_model_class()
- from lightllm.models import Deepseek2TpPartModel
+ from lightllm.models import Deepseek2TpPartModel, NeoTpMOEPartModel, NeoTpPartModel
if issubclass(model_class, Deepseek2TpPartModel):
mem_class = Deepseek2MemoryManager
logger.info(f"Model kv cache using default, mem_manager class: {mem_class}")
return mem_class
+    # Then check whether the model belongs to the neo family.
+ elif issubclass(model_class, NeoTpMOEPartModel) or issubclass(model_class, NeoTpPartModel):
+ mem_class = NeoMemoryManager
+ return mem_class
# case normal
logger.info(f"mode setting params: {get_env_start_args().llm_kv_type}")
diff --git a/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
new file mode 100755
index 0000000000..0a79aa072b
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
@@ -0,0 +1,46 @@
+import torch
+from lightllm.utils.dist_utils import get_current_rank_in_node
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
+
+
+class NeoMemoryManager(MemoryManager):
+ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+ self.size = size
+ self.head_num = head_num
+        self.head_dim = head_dim * 2  # neo kv stores [k, k_h, k_w] concatenated together
+ self.layer_num = layer_num
+ self.always_copy = always_copy
+ self.dtype = dtype
+ # profile the max total token num if the size is None
+ self.profile_size(mem_fraction)
+
+ self.mem_state = torch.arange(
+ 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+ )
+ self._mem_state_return = torch.arange(
+ 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+ )
+ self._return_start = 0
+ self.mark_start = 0
+ self.mark_end = self.size
+
+ self.can_use_mem_size = self.size
+
+        # Shared via shared memory so the router module can read it for accurate scheduling estimates;
+        # the nccl port marks a single instance on one machine to avoid conflicts.
+ from lightllm.utils.envs_utils import get_unique_server_name
+
+ rank_in_node = get_current_rank_in_node()
+ self.shared_can_use_token_num = SharedInt(
+ f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
+ )
+
+ self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+ self._init_buffers(
+ self.size,
+ dtype,
+ head_num,
+ self.head_dim,
+ layer_num,
+ )
+ self.HOLD_TOKEN_MEMINDEX = self.size
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..c75c871c72
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..14026090e6
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "800": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json
new file mode 100644
index 0000000000..939c939523
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 2,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 16
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8448": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..13ba4ba8e5
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "67584": {
+ "BLOCK_M": 64,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..ee316f610b
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..e027701092
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..ddda23d257
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..560ca6c09d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..0713de7996
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..e950ff0954
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..7f479b8382
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 2
+ },
+ "100": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 2
+ },
+ "8448": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..b3051c6584
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 2,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "8448": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..fdb3212216
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "64": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..a94e669353
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..441421fd5d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "32768": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "67584": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index 095f736791..81208286b2 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -37,4 +37,6 @@
Tarsier2LlamaTpPartModel,
)
from lightllm.models.gpt_oss.model import GptOssTpPartModel
+from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel
+from lightllm.models.neo_chat.model import NeoTpPartModel
from .registry import get_model, get_model_class
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index 98cc7c229e..97015f6b20 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -312,6 +312,7 @@ def _moe_ffn(
use_grouped_topk=self.n_group,
topk_group=self.topk_group,
num_expert_group=self.n_group,
+ microbatch_index=infer_state.microbatch_index,
)
if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
@@ -339,6 +340,7 @@ def _moe_ffn_edp(
topk_group=self.topk_group,
num_expert_group=self.n_group,
is_prefill=infer_state.is_prefill,
+ microbatch_index=infer_state.microbatch_index,
)
if self.n_shared_experts is not None:
diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
index 3eb09f9176..bd72035072 100644
--- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -242,6 +242,9 @@ def _init_moe(self):
        # == 0 means there are no fused shared experts; the shared experts are loaded and run separately.
if self.num_fused_shared_experts == 0:
self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True)
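+        # map this layer to its index among the MoE layers: the first `first_k_dense_replace`
+        # layers are dense, after that one MoE layer appears every `moe_layer_freq` layers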
+ first_moe = self.network_config_["first_k_dense_replace"]
+ freq = self.network_config_.get("moe_layer_freq", 1)
+ moe_layer_index = (self.layer_num_ - first_moe) // freq
self.experts = FusedMoeWeight(
gate_proj_name="gate_proj",
down_proj_name="down_proj",
@@ -256,6 +259,7 @@ def _init_moe(self):
num_fused_shared_experts=self.num_fused_shared_experts,
layer_num=self.layer_num_,
network_config=self.network_config_,
+ moe_layer_index=moe_layer_index,
)
def _init_ffn(self):
diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py
index e596eed97c..79bd327068 100644
--- a/lightllm/models/deepseek2/model.py
+++ b/lightllm/models/deepseek2/model.py
@@ -6,6 +6,7 @@
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
+from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num
from lightllm.distributed.communication_op import dist_group_manager
@@ -49,6 +50,9 @@ def _init_some_value(self):
def _init_custom(self):
self._init_to_get_yarn_rotary()
dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"])
+ if self.args.enable_return_routed_experts:
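+            # only the MoE layers produce routing decisions; the leading dense layers are skipped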
+ num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe)
+ init_routing_capture(self, num_moe_layers)
def _verify_params(self):
return super()._verify_params()
diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
index d80eefd16e..e5672f8210 100644
--- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
@@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -
use_grouped_topk=False,
topk_group=None,
num_expert_group=None,
+ microbatch_index=infer_state.microbatch_index,
)
return hidden_states.view(num_tokens, hidden_dim)
diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py
index 7c8c30940e..7278c62fec 100644
--- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py
@@ -55,6 +55,7 @@ def _init_moe(self):
num_fused_shared_experts=0,
layer_num=self.layer_num_,
network_config=self.network_config_,
+ moe_layer_index=self.layer_num_,
)
def _init_weight_names(self):
diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py
index 9e9561eb24..cff748933d 100644
--- a/lightllm/models/gpt_oss/model.py
+++ b/lightllm/models/gpt_oss/model.py
@@ -2,6 +2,7 @@
from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.models.registry import ModelRegistry
+from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class
@@ -21,6 +22,12 @@ class GptOssTpPartModel(LlamaTpPartModel):
def __init__(self, kvargs):
super().__init__(kvargs)
+ def _init_custom(self):
+ super()._init_custom()
+ if self.args.enable_return_routed_experts:
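+            # every GPT-OSS transformer layer is an MoE layer, so routing is captured for all of them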
+ num_moe_layers = len(self.trans_layers_weight)
+ init_routing_capture(self, num_moe_layers)
+
def _init_att_backend(self):
self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])(
model=self
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
old mode 100644
new mode 100755
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index c104ebccc9..cc1dc28178 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -74,14 +74,19 @@ def _init_custom(self):
rope_scaling = self.config.get("rope_scaling", None)
if rope_scaling is None:
self._init_to_get_rotary()
- return
-
- if "rope_type" in rope_scaling:
+ elif "rope_type" in rope_scaling:
scaling_type = rope_scaling["rope_type"]
+ self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
elif "type" in rope_scaling:
scaling_type = rope_scaling["type"]
+ self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
else:
raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
+ if "rope_theta_hw" in self.config:
+ self._init_to_get_hw_rotary()
+ super()._init_custom()
+
+ def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling):
if scaling_type == "default" or "mrope_section" in rope_scaling:
self._init_to_get_rotary()
elif scaling_type == "yarn":
@@ -96,7 +101,6 @@ def _init_custom(self):
self._init_to_get_rotary()
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
- return
def _init_to_get_rotary(self, default_base=10000):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
@@ -106,7 +110,6 @@ def _init_to_get_rotary(self, default_base=10000):
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
base = self.config.get("rope_theta", float(default_base))
-
if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
@@ -139,6 +142,46 @@ def _init_to_get_rotary(self, default_base=10000):
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
+ def _init_to_get_hw_rotary(self, default_base=10000):
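+        # the h/w rotary tables cover half of each head dim: the h and w position components use head_dim // 2 each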
+ partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)
+ if self.config.get("rope_scaling", {}) is None:
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+
+ base = self.config.get("rope_theta_hw", float(default_base))
+ if "max_sequence_length" in self.config:
+ max_seq_len = self.config["max_sequence_length"]
+ else:
+ max_position_embeddings = self.config.get(
+ "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
+ )
+ max_seq_len = max_position_embeddings * rope_scaling_factor
+
+ # NTK
+ try:
+ ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+ assert ntk_alpha >= 1
+ if ntk_alpha > 1:
+ logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+ max_seq_len *= ntk_alpha
+ base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
+        except Exception:
+ pass
+
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+ )
+ t = (
+ torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
+ / rope_scaling_factor
+ )
+ freqs = torch.outer(t, inv_freq)
+
+ self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+ self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda()
+ return
+
def _init_to_get_dynamic_ntk_rotary(self):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
max_position_embeddings = self.config.get("max_position_embeddings", 2048)
diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
index 45de83e989..223a64ad51 100644
--- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
+++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@@ -75,7 +75,8 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256}
- sm_scale = 1.0 / (Lk ** 0.5)
+ Lk_scale = Lk // 2
+ sm_scale = 1.0 / (Lk_scale ** 0.5)
batch, head_num = B_req_idx.shape[0], q.shape[1]
diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py
deleted file mode 100644
index b0e27ac1de..0000000000
--- a/lightllm/models/mixtral/layer_infer/_custom_ops.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import functools
-import json
-import os
-from typing import Any, Dict, Optional, Tuple
-
-import torch
-import triton
-import triton.language as tl
-from lightllm.utils.log_utils import init_logger
-
-logger = init_logger(__name__)
-
-# Pytorch version
-# Triton version in progress
-def topk_softmax(
- topk_weights,
- topk_ids,
- token_expert_indicies,
- gating_output,
- topk=2,
-):
- scores = torch.softmax(gating_output, dim=-1)
- topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)
- return topk_weights, topk_ids
-
-
-def fused_topk(
- hidden_states: torch.Tensor,
- gating_output: torch.Tensor,
- topk: int,
- renormalize: bool,
- alloc_tensor_func=torch.empty,
-):
- assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
- M, _ = hidden_states.shape
-
- topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device)
- topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
- token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
- topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk)
- del token_expert_indicies # Not used. Will be used in the future.
-
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
- return topk_weights, topk_ids
diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
index 44e66cff2d..a2968f5ab1 100644
--- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
@@ -1,9 +1,6 @@
-import os
import torch
-import torch.nn.functional as F
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
-from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk
from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight
@@ -19,25 +16,15 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor
hidden_states = input.view(-1, self.embed_dim_)
num_tokens, hidden_dim = hidden_states.shape
- router_logits = layer_weight.moe_gate.mm(input.view(-1, self.embed_dim_))
- topk_weights, topk_ids = fused_topk(
- hidden_states=hidden_states,
- gating_output=router_logits,
- topk=self.num_experts_per_tok,
+ router_logits = layer_weight.moe_gate.mm(hidden_states)
+ layer_weight.experts.experts(
+ hidden_states,
+ router_logits=router_logits,
+ top_k=self.num_experts_per_tok,
renormalize=self.renormalize,
- alloc_tensor_func=self.alloc_tensor,
- )
- from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl
-
- return fused_experts_impl(
- hidden_states=hidden_states,
- w1=layer_weight.experts.w1[0],
- w2=layer_weight.experts.w2[0],
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- use_fp8_w8a8=False,
- w1_scale=None,
- w2_scale=None,
- alloc_tensor_func=self.alloc_tensor,
+ use_grouped_topk=False,
+ topk_group=None,
+ num_expert_group=None,
+ microbatch_index=getattr(infer_state, "microbatch_index", 0),
)
+ return hidden_states.view(num_tokens, hidden_dim)
diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py
index 51c62fd4cb..d93cb5fb58 100644
--- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py
@@ -57,4 +57,5 @@ def _init_moe(self):
quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
layer_num=self.layer_num_,
network_config=self.network_config_,
+ moe_layer_index=self.layer_num_,
)
diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py
index 3c2d7b4e87..35bf38de58 100644
--- a/lightllm/models/mixtral/model.py
+++ b/lightllm/models/mixtral/model.py
@@ -2,6 +2,7 @@
import numpy as np
from lightllm.models.registry import ModelRegistry
from lightllm.common.basemodel.basemodel import TpPartBaseModel
+from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.common.kv_cache_mem_manager import MemoryManager
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
@@ -45,6 +46,9 @@ def _verify_params(self):
def _init_custom(self):
self._init_to_get_rotary()
+ if self.args.enable_return_routed_experts:
+ num_moe_layers = len(self.trans_layers_weight)
+ init_routing_capture(self, num_moe_layers)
return
def _init_mem_manager(self):
diff --git a/lightllm/models/neo_chat/__init__.py b/lightllm/models/neo_chat/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat/layer_infer/__init__.py b/lightllm/models/neo_chat/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
new file mode 100644
index 0000000000..a3436b28ee
--- /dev/null
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,159 @@
+import torch
+from functools import partial
+from typing import Tuple
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
+from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
+from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
+from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight
+from lightllm.distributed import all_reduce
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+
+
+class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+ return
+
+ def _bind_attention(self):
+ self._context_attention_kernel = self._context_attention_kernel
+ self._token_attention_kernel = self._token_decode_attention_normal
+ self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
+ return
+
+ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight):
+ input = input.view(-1, self.embed_dim_)
+ q = layer_weight.q_proj.mm(input) # [T, Hq*D]
+
+ q_hw = layer_weight.q_hw_proj.mm(input)
+ q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_)
+ q_h, q_w = q_hw.chunk(2, dim=-1)
+
+ k_hw = layer_weight.k_hw_proj.mm(input)
+ k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_)
+ k_h, k_w = k_hw.chunk(2, dim=-1)
+
+ cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D]
+
+ layer_weight.q_norm_weight_(q, eps=self.eps_)
+
+ q_h_2d = q_h.reshape(q.shape[0], -1)
+ q_w_2d = q_w.reshape(q.shape[0], -1)
+ layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_)
+ layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_)
+ q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+ q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+
+ layer_weight.k_norm_weight_(
+ cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
+ eps=self.eps_,
+ )
+
+ k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)]
+ k_w_2d = k_w.reshape(q.shape[0], -1)
+ layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_)
+ layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_)
+ k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+ k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+
+ cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
+ rotary_emb_fwd(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ cache_kv[:, : self.tp_k_head_num_, :],
+ infer_state.position_cos,
+ infer_state.position_sin,
+ )
+ rotary_emb_fwd(
+ q_h,
+ k_h,
+ infer_state.position_cos_h,
+ infer_state.position_sin_h,
+ )
+ rotary_emb_fwd(
+ q_w,
+ k_w,
+ infer_state.position_cos_w,
+ infer_state.position_sin_w,
+ )
+
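+        # pack each head as [rotary D | height D/2 | width D/2] (2 * head_dim in total);
+        # V is zero-padded to the same width so K and V share one kv-cache layout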
+ q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+ q3 = torch.cat([q3, q_h, q_w], dim=-1)
+ q = q3.reshape(q3.shape[0], -1)
+
+ k = cache_kv[:, : self.tp_k_head_num_, :]
+ k = torch.cat([k, k_h, k_w], dim=-1)
+
+ v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
+ v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype)
+ v = torch.cat([v, v_pad], dim=-1)
+
+ cache_kv = torch.cat([k, v], dim=1)
+ return q, cache_kv
+
+ def _context_attention_kernel(
+ self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
+ ) -> torch.Tensor:
+ o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+ kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+ context_attention_fwd_neo(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+ kv[:, 0 : self.tp_k_head_num_, :],
+ kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+ o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+ infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
+ infer_state.b_req_idx,
+ infer_state.b_start_loc,
+ infer_state.b_seq_len,
+ infer_state.b_ready_cache_len,
+ infer_state.max_len_in_batch,
+ infer_state.req_manager.req_to_token_indexs,
+ )
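+        # only the first head_dim channels of the packed 2 * head_dim output are meaningful (V was zero-padded)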
+ o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
+ o3 = o3[:, :, : self.head_dim_].contiguous()
+ return o3.view(o3.shape[0], -1)
+
+ def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight):
+ total_token_num = infer_state.total_token_num
+ batch_size = infer_state.batch_size
+
+ q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2)
+
+ att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32)
+
+ k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
+ token_att_fwd(
+ q_3d,
+ k_3d,
+ att_m_tensor,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_req_idx,
+ infer_state.b_kv_start_loc,
+ infer_state.b_seq_len,
+ infer_state.max_kv_seq_len,
+ )
+
+ from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import (
+ token_attention_softmax_and_reducev,
+ )
+
+ token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd
+
+ v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][
+ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_
+ ]
+
+ o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype)
+
+ token_softmax_reducev_fwd(
+ att_m_tensor,
+ v_3d,
+ o_3d,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_req_idx,
+ infer_state.b_kv_start_loc,
+ infer_state.b_seq_len,
+ )
+ return o_3d.view(batch_size, -1)
diff --git a/lightllm/models/neo_chat/layer_weights/__init__.py b/lightllm/models/neo_chat/layer_weights/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 0000000000..e6489f39af
--- /dev/null
+++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,23 @@
+import torch
+import numpy as np
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+
+# Strip the "language_model." prefix from weight keys: language_model.xxx -> xxx.
+# Only the PreAndPostLayerWeight keys need renaming; the TransformerLayerWeight keys are already correct.
+def rename_weight_keys(weights):
+ prefix = "language_model."
+ keys = list(weights.keys())
+ for k in keys:
+ if prefix in k:
+ weights[k.replace(prefix, "")] = weights.pop(k)
+
+
+class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+ return
+
+ def load_hf_weights(self, weights):
+ rename_weight_keys(weights)
+ super().load_hf_weights(weights)
+ return
diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py
new file mode 100644
index 0000000000..e62afae9bc
--- /dev/null
+++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,67 @@
+from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ QKRMSNORMWeight,
+ ROWMMWeight,
+)
+
+
+class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, quant_cfg)
+ return
+
+ def _init_weight_names(self):
+ super()._init_weight_names()
+ self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight"
+ self._q_bias_hw_name = None
+ self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight"
+ self._k_bias_hw_name = None
+
+ self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight"
+ self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight"
+
+ self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight"
+ self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight"
+
+ def _init_qkv(self):
+ super()._init_qkv()
+ self.q_hw_proj = ROWMMWeight(
+ in_dim=self.network_config_["hidden_size"],
+ out_dims=[self.q_head_num_ * self.head_dim],
+ weight_names=self._q_weight_hw_name,
+ data_type=self.data_type_,
+ bias_names=self._q_bias_hw_name,
+ quant_method=self.get_quant_method("q_hw_proj"),
+ )
+ self.k_hw_proj = ROWMMWeight(
+ in_dim=self.network_config_["hidden_size"],
+ out_dims=[self.k_head_num_ * self.head_dim],
+ weight_names=self._k_weight_hw_name,
+ data_type=self.data_type_,
+ bias_names=self._k_bias_hw_name,
+ quant_method=self.get_quant_method("k_hw_proj"),
+ )
+
+ def _init_norm(self):
+ super()._init_norm()
+
+ self.q_norm_h_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ weight_name=self._q_norm_h_name,
+ data_type=self.data_type_,
+ )
+ self.q_norm_w_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ weight_name=self._q_norm_w_name,
+ data_type=self.data_type_,
+ )
+ self.k_norm_h_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ weight_name=self._k_norm_h_name,
+ data_type=self.data_type_,
+ )
+ self.k_norm_w_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ weight_name=self._k_norm_w_name,
+ data_type=self.data_type_,
+ )
diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py
new file mode 100644
index 0000000000..14d9f96dc7
--- /dev/null
+++ b/lightllm/models/neo_chat/model.py
@@ -0,0 +1,53 @@
+import os
+import json
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry, llm_model_type_is
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
+from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight
+from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
+from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.server.core.objs import SamplingParams
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+from lightllm.models.neo_chat_moe.vision_process import smart_resize
+from lightllm.models.internvl.model import InternvlTokenizer
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatTransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight
+from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight
+from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+
+
+@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3"))
+class NeoTpPartModel(Qwen3TpPartModel):
+
+ pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+ transformer_layer_infer_class = NeoChatTransformerLayerInfer
+
+ pre_and_post_weight_class = NeoChatPreAndPostLayerWeight
+ transformer_weight_class = NeoChatTransformerLayerWeight
+
+ infer_state_class = NeoChatInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_inferstate_cls(self):
+ pass
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
+ self.config = all_config["llm_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
diff --git a/lightllm/models/neo_chat_moe/__init__.py b/lightllm/models/neo_chat_moe/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py
new file mode 100644
index 0000000000..961ed2a61d
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/infer_struct.py
@@ -0,0 +1,103 @@
+from typing import Optional, List
+import torch
+import numpy as np
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.common.req_manager import ReqManager
+from lightllm.models.neo_chat_moe.triton_kernel.get_neo_position import get_neo_position_triton
+from lightllm.models.llama.model import LlamaTpPartModel
+
+
+class NeoChatInferStateInfo(LlamaInferStateInfo):
+ def __init__(self):
+ super().__init__()
+ self.position_cos = None
+ self.position_sin = None
+ self.position_cos_h = None
+ self.position_sin_h = None
+ self.position_cos_w = None
+ self.position_sin_w = None
+
+ def init_some_extra_state(self, model: LlamaTpPartModel):
+ LlamaInferStateInfo.init_some_extra_state(self, model)
+ if self.is_prefill:
+ self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda(
+ non_blocking=True
+ )
+ self.position_ids = self.get_neo_position(self.multimodal_params)
+ else:
+ b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
+ for batch_idx, p in enumerate(self.multimodal_params):
+ position_delta = 0
+ for image in p["images"]:
+ position_delta += image["grid_thwd"][3]
+ b_position_delta[batch_idx] = position_delta
+ position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
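+            # expand to 3 position rows (t, h, w); during decode only the temporal row advances, the h/w rows stay zero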
+ self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone()
+ self.position_ids[1:].zero_()
+
+ self.position_ids = self.position_ids.contiguous()
+ self.position_cos = model._cos_cached[self.position_ids[0]]
+ self.position_sin = model._sin_cached[self.position_ids[0]]
+ self.position_cos_h = model._hw_cos_cached[self.position_ids[1]]
+ self.position_sin_h = model._hw_sin_cached[self.position_ids[1]]
+ self.position_cos_w = model._hw_cos_cached[self.position_ids[2]]
+ self.position_sin_w = model._hw_sin_cached[self.position_ids[2]]
+ return
+
+ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
+ if len(multimodal_params) == 0:
+ position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+ position_ids[0].copy_(self.position_ids)
+ return position_ids
+ b_image_start_idx = []
+ b_image_nums = []
+ b_image_start_num = []
+ b_image_len = []
+ image_start_num = 0
+ b_image_thwd = []
+
+ # pad multimodal_params to batch size.
+ batch_size = self.b_q_seq_len.shape[0]
+ multimodal_params = multimodal_params + [
+ {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params))
+ ]
+
+ for _, p in enumerate(multimodal_params):
+ images = p.get("images", [])
+ for img in images:
+ b_image_start_idx.append(img["start_idx"])
+ b_image_len.append(img["token_num"])
+ b_image_thwd.append(img["grid_thwd"])
+ b_image_nums.append(len(images))
+ b_image_start_num.append(image_start_num)
+ image_start_num += len(images)
+
+        # there are no images at all in this batch
+ if image_start_num == 0:
+ position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+ position_ids[0].copy_(self.position_ids)
+ return position_ids.contiguous()
+ b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True)
+ b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4
+ b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
+ b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
+ b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)
+
+ position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+ position_ids[0].copy_(self.position_ids)
+
+ get_neo_position_triton(
+ b_image_start_idx=b_image_start_idx,
+ b_image_thwd=b_image_thwd,
+ b_image_nums=b_image_nums,
+ b_image_start_num=b_image_start_num,
+ b_image_len=b_image_len,
+ position_ids=position_ids,
+ b_ready_cache_len=self.b_ready_cache_len,
+ b_q_seq_len=self.b_q_seq_len,
+ b_start_loc=self.b_q_start_loc,
+ b_image_token_tag=self.b_image_token_tag,
+ )
+ return position_ids
diff --git a/lightllm/models/neo_chat_moe/layer_infer/__init__.py b/lightllm/models/neo_chat_moe/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
new file mode 100644
index 0000000000..1518d68748
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,250 @@
+import torch
+from functools import partial
+from typing import Tuple
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
+from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
+from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
+from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.distributed import all_reduce
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.common.basemodel.attention.base_att import AttControl
+
+
+class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
+ def __init__(self, data_type, network_config):
+ self._is_merge_kv = network_config.get("merge_kv", True)
+ super().__init__(data_type, network_config)
+ return
+
+ def _bind_attention(self):
+ self._context_attention_kernel = self._context_attention_kernel
+ self._token_attention_kernel = self._token_decode_attention_normal
+ self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
+ return
+
+ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight):
+ if self._is_merge_kv:
+ return self._get_qkv_mergekv(input, infer_state, layer_weight)
+ else:
+ return self._get_qkv_not_mergekv(input, infer_state, layer_weight)
+
+ def _get_qkv_not_mergekv(
+ self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight
+ ):
+ input = input.view(-1, self.embed_dim_)
+ q = layer_weight.q_proj.mm(input) # [T, Hq*D]
+
+ q_hw = layer_weight.q_hw_proj.mm(input)
+ q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_)
+ q_h, q_w = q_hw.chunk(2, dim=-1)
+
+ k_hw = layer_weight.k_hw_proj.mm(input)
+ k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_)
+ k_h, k_w = k_hw.chunk(2, dim=-1)
+
+ cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D]
+
+ layer_weight.q_norm_weight_(q, eps=self.eps_)
+
+ q_h_2d = q_h.reshape(q.shape[0], -1)
+ q_w_2d = q_w.reshape(q.shape[0], -1)
+ layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_)
+ layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_)
+ q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+ q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+
+ layer_weight.k_norm_weight_(
+ cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
+ eps=self.eps_,
+ )
+
+ k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)]
+ k_w_2d = k_w.reshape(q.shape[0], -1)
+ layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_)
+ layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_)
+ k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+ k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+
+ cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
+ rotary_emb_fwd(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ cache_kv[:, : self.tp_k_head_num_, :],
+ infer_state.position_cos,
+ infer_state.position_sin,
+ )
+ rotary_emb_fwd(
+ q_h,
+ k_h,
+ infer_state.position_cos_h,
+ infer_state.position_sin_h,
+ )
+ rotary_emb_fwd(
+ q_w,
+ k_w,
+ infer_state.position_cos_w,
+ infer_state.position_sin_w,
+ )
+
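+        # same per-head packing as NeoChatTransformerLayerInfer._get_qkv: [rope D | h D/2 | w D/2], V zero-padded to 2 * head_dim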
+ q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+ q3 = torch.cat([q3, q_h, q_w], dim=-1)
+ q = q3.reshape(q3.shape[0], -1)
+
+ k = cache_kv[:, : self.tp_k_head_num_, :]
+ k = torch.cat([k, k_h, k_w], dim=-1)
+
+ v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
+ v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype)
+ v = torch.cat([v, v_pad], dim=-1)
+
+ cache_kv = torch.cat([k, v], dim=1)
+ return q, cache_kv
+
+ def _get_qkv_mergekv(
+ self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight
+ ):
+ input = input.view(-1, self.embed_dim_)
+
+ qkv = layer_weight.qkv_proj.mm(input)
+ q, cache_kv = qkv.split(
+ [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1
+ )
+ q_hw = layer_weight.q_hw_proj.mm(input)
+ k_hw = layer_weight.k_hw_proj.mm(input)
+
+ layer_weight.q_norm_weight_(q, eps=self.eps_)
+ layer_weight.q_norm_hw_weight_(q_hw, eps=self.eps_)
+ layer_weight.k_norm_hw_weight_(k_hw, eps=self.eps_)
+
+ q_hw = q_hw.view(q.shape[0], self.tp_q_head_num_, self.head_dim_)
+ q_h, q_w = q_hw.chunk(2, dim=-1)
+
+ layer_weight.k_norm_weight_(
+ cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
+ eps=self.eps_,
+ )
+
+ k_hw = k_hw.view(q.shape[0], self.tp_k_head_num_, self.head_dim_)
+ k_h, k_w = k_hw.chunk(2, dim=-1)
+
+ cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
+ rotary_emb_fwd(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ cache_kv[:, : self.tp_k_head_num_, :],
+ infer_state.position_cos,
+ infer_state.position_sin,
+ )
+ rotary_emb_fwd(
+ q_h,
+ k_h,
+ infer_state.position_cos_h,
+ infer_state.position_sin_h,
+ )
+ rotary_emb_fwd(
+ q_w,
+ k_w,
+ infer_state.position_cos_w,
+ infer_state.position_sin_w,
+ )
+
+ q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+ q3 = torch.cat([q3, q_h, q_w], dim=-1)
+ q = q3.reshape(q3.shape[0], -1)
+
+ k = cache_kv[:, : self.tp_k_head_num_, :]
+ k = torch.cat([k, k_h, k_w], dim=-1)
+
+ v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
+ v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype)
+ v = torch.cat([v, v_pad], dim=-1)
+
+ cache_kv = torch.cat([k, v], dim=1)
+ return q, cache_kv
+
+ def _context_attention_kernel(
+ self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
+ ) -> torch.Tensor:
+ o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+ kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+ context_attention_fwd_neo(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+ kv[:, 0 : self.tp_k_head_num_, :],
+ kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+ o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+ infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
+ infer_state.b_req_idx,
+ infer_state.b_q_start_loc,
+ infer_state.b_seq_len,
+ infer_state.b_ready_cache_len,
+ infer_state.max_q_seq_len,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_image_token_tag,
+ )
+ o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
+ o3 = o3[:, :, : self.head_dim_].contiguous()
+ return o3.view(o3.shape[0], -1)
+
+ def _token_attention_kernel(
+ self,
+ q: torch.Tensor,
+ infer_state: NeoChatInferStateInfo,
+ layer_weight: NeoChatMOETransformerLayerWeight,
+ ) -> torch.Tensor:
+ _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_)
+ _q = q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
+        att_control = AttControl()
+        # q/k are packed to 2 * head_dim, but the softmax scale should still use the logical head_dim
+        att_control.scale = 1.0 / (self.head_dim_ ** 0.5)
+ o_tensor = infer_state.decode_att_state.decode_att(
+ q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor
+ )
+ o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)[:, :, : self.head_dim_].contiguous()
+ return o_tensor
+
diff --git a/lightllm/models/neo_chat_moe/layer_weights/__init__.py b/lightllm/models/neo_chat_moe/layer_weights/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 0000000000..4b0eae91c3
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,23 @@
+import torch
+import numpy as np
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+
+# rename keys: language_model.xxx -> xxx
+# only rename keys when PreAndPostLayerWeight loads; TransformerLayerWeight keys are already correct
+def rename_weight_keys(weights):
+    prefix = "language_model."
+    for k in list(weights.keys()):
+        if k.startswith(prefix):
+            weights[k[len(prefix) :]] = weights.pop(k)
+
+
+class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+ return
+
+ def load_hf_weights(self, weights):
+ rename_weight_keys(weights)
+ super().load_hf_weights(weights)
+ return
diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py
new file mode 100644
index 0000000000..26e986cdd7
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,83 @@
+from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ QKRMSNORMWeight,
+ ROWMMWeight,
+)
+
+
+class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
+ self._is_merge_kv = network_config.get("merge_kv", True)
+ super().__init__(layer_num, data_type, network_config, quant_cfg)
+ return
+
+ def _init_weight_names(self):
+ super()._init_weight_names()
+ self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight"
+ self._q_bias_hw_name = None
+ self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight"
+ self._k_bias_hw_name = None
+
+ if self._is_merge_kv:
+ self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight"
+ self._k_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_hw.weight"
+ else:
+ self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight"
+ self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight"
+
+ self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight"
+ self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight"
+
+ def _init_qkv(self):
+ super()._init_qkv()
+ self.q_hw_proj = ROWMMWeight(
+ in_dim=self.network_config_["hidden_size"],
+ out_dims=[self.q_head_num_ * self.head_dim],
+ weight_names=self._q_weight_hw_name,
+ data_type=self.data_type_,
+ bias_names=self._q_bias_hw_name,
+ quant_method=self.get_quant_method("q_hw_proj"),
+ )
+ self.k_hw_proj = ROWMMWeight(
+ in_dim=self.network_config_["hidden_size"],
+ out_dims=[self.k_head_num_ * self.head_dim],
+ weight_names=self._k_weight_hw_name,
+ data_type=self.data_type_,
+ bias_names=self._k_bias_hw_name,
+ quant_method=self.get_quant_method("k_hw_proj"),
+ )
+
+ def _init_norm(self):
+ super()._init_norm()
+ if self._is_merge_kv:
+ self.q_norm_hw_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim,
+ weight_name=self._q_norm_hw_name,
+ data_type=self.data_type_,
+ )
+ self.k_norm_hw_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim,
+ weight_name=self._k_norm_hw_name,
+ data_type=self.data_type_,
+ )
+ else:
+ self.q_norm_h_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ weight_name=self._q_norm_h_name,
+ data_type=self.data_type_,
+ )
+ self.q_norm_w_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ weight_name=self._q_norm_w_name,
+ data_type=self.data_type_,
+ )
+ self.k_norm_h_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ weight_name=self._k_norm_h_name,
+ data_type=self.data_type_,
+ )
+ self.k_norm_w_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ weight_name=self._k_norm_w_name,
+ data_type=self.data_type_,
+ )
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
new file mode 100644
index 0000000000..cf4404090f
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -0,0 +1,150 @@
+import os
+import json
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry, llm_model_type_is
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
+from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight
+from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
+from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.server.core.objs import SamplingParams
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+from lightllm.models.neo_chat_moe.vision_process import smart_resize
+from lightllm.models.internvl.model import InternvlTokenizer
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.neo_chat_moe.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight
+from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+
+IMG_START_TOKEN = ""
+IMG_END_TOKEN = ""
+IMG_TOKEN = ""
+AUDIO_START_TOKEN = ""
+
+
+class NeoChatTokenizer(BaseMultiModalTokenizer):
+ def __init__(self, tokenizer, model_cfg, **kwargs):
+ super().__init__(tokenizer)
+ self.tokenizer = tokenizer
+ self.min_pixel = model_cfg.get("vision_config").get("min_pixels")
+ self.max_pixel = model_cfg.get("vision_config").get("max_pixels")
+ self.patch_size = model_cfg.get("vision_config").get("patch_size")
+ self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio")
+
+ self.image_token_id = model_cfg.get("image_token_id")
+ self.image_start_tag = IMG_START_TOKEN
+ self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag)
+ self.image_end_tag = IMG_END_TOKEN
+ self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
+
+ def init_imageitem_extral_params(
+ self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+ ):
+ img.extra_params["min_pixels"] = (
+ sampling_params.min_pixels if sampling_params.min_pixels > 0 else self.min_pixel
+ )
+ img.extra_params["max_pixels"] = (
+ sampling_params.max_pixels if sampling_params.max_pixels > 0 else self.max_pixel
+ )
+ assert (
+ img.extra_params["min_pixels"] <= img.extra_params["max_pixels"]
+ ), "min_pixels should be less than or equal to max_pixels"
+ return
+
+ def init_audioitem_extral_params(
+ self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+ ):
+ raise NotImplementedError
+
+ def get_audio_token_length(self, audio: AudioItem):
+ raise NotImplementedError
+
+ def get_image_token_length(self, img: ImageItem):
+ width, height = img.image_w, img.image_h
+ resized_height, resized_width = smart_resize(
+ height=height,
+ width=width,
+ factor=int(self.patch_size // self.downsample_ratio),
+ min_pixels=img.extra_params["min_pixels"],
+ max_pixels=img.extra_params["max_pixels"],
+ )
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+ token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2))
+        # do grid_h and grid_w need to be multiplied by self.downsample_ratio here? double-check the code
+ img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num)
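+        # e.g. with patch_size=16, downsample_ratio=0.5 and a 768x512 image (already a multiple
+        # of 32 on both sides and inside the pixel bounds): grid_h=32, grid_w=48,
+        # token_num = int(32 * 48 * 0.25) = 384 and grid_thwd = (1, 16, 24, -383)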
+ return token_num
+
+ # only change the impl of the encode func:
+ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+        # replace each image placeholder in the prompt with an image start/end tag pair
+ image_tokens = IMG_START_TOKEN + IMG_END_TOKEN
+ if multimodal_params is None:
+ add_special_tokens = kwargs.get("add_special_tokens", True)
+ return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+ image_count = len(multimodal_params.images)
+ if not kwargs.get("already_tokenized", False):
+ prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count)
+ origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"])
+ else:
+ origin_ids = prompt
+        # expand each image start/end tag pair into the image's token ids:
+        # start_id, token_id, token_id + 1, ..., token_id + token_num - 1, end_id
+ input_ids = []
+ image_id = 0
+ start_idx = 0
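+        # e.g. origin_ids = [t0, start_id, end_id, t1] with token_id=9000 and token_num=3
+        # becomes [t0, start_id, 9000, 9001, 9002, end_id, t1] (9000 is an illustrative id)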
+ while True:
+ try:
+ start_idx = origin_ids.index(self.image_start_id)
+ if start_idx + 1 >= len(origin_ids):
+ break
+ if origin_ids[start_idx + 1] == self.image_end_id:
+ input_ids.extend(origin_ids[: start_idx + 1])
+ token_id = multimodal_params.images[image_id].token_id
+ token_num = multimodal_params.images[image_id].token_num
+ multimodal_params.images[image_id].start_idx = len(input_ids)
+ input_ids.extend(range(token_id, token_id + token_num))
+ input_ids.append(self.image_end_id)
+ origin_ids = origin_ids[start_idx + 2 :]
+ image_id += 1
+ else:
+ raise ValueError("image token error")
+ except ValueError:
+ break
+ input_ids.extend(origin_ids)
+ return input_ids
+
+
+@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe"))
+class NeoTpMOEPartModel(Qwen3MOEModel):
+
+ pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+ transformer_layer_infer_class = NeoChatMOETransformerLayerInfer
+
+ pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight
+ transformer_weight_class = NeoChatMOETransformerLayerWeight
+
+ infer_state_class = NeoChatInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_inferstate_cls(self):
+ pass
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
+ self.config = all_config["llm_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py
new file mode 100644
index 0000000000..60fa82f2b9
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/neo_visual.py
@@ -0,0 +1,281 @@
+import os
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from typing import List
+from io import BytesIO
+import torch.nn as nn
+from transformers.activations import ACT2FN
+from safetensors import safe_open
+from lightllm.server.multimodal_params import ImageItem
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from lightllm.models.neo_chat_moe.vision_process import load_image_native
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+
+
+def apply_rotary_emb_1d(
+ x: torch.Tensor,
+ cos_cached: torch.Tensor,
+ sin_cached: torch.Tensor,
+ positions: torch.Tensor,
+):
+ """对输入张量的一部分应用1D RoPE。"""
+ # x: (..., seq_len, dim_part)
+ # positions: (..., seq_len)
+ # cos_cached: (max_pos, dim_part / 2)
+ cos_cached = cos_cached.to(device=positions.device)
+ sin_cached = sin_cached.to(device=positions.device)
+
+ cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2)
+ sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2)
+
+ x1 = x[..., 0::2]
+ x2 = x[..., 1::2]
+
+ rotated_x1 = x1 * cos - x2 * sin
+ rotated_x2 = x1 * sin + x2 * cos
+
+ x_rotated = torch.empty_like(x)
+ x_rotated[..., 0::2] = rotated_x1
+ x_rotated[..., 1::2] = rotated_x2
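+    # each even/odd feature pair (x[..., 2i], x[..., 2i+1]) is rotated by the angle
+    # positions * inv_freq_i encoded in cos_cached / sin_cached (the standard RoPE rotation)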
+ return x_rotated
+
+
+def apply_2d_rotary_pos_emb(
+ x: torch.Tensor,
+ cos_cached_x: torch.Tensor,
+ sin_cached_x: torch.Tensor,
+ cos_cached_y: torch.Tensor,
+ sin_cached_y: torch.Tensor,
+ abs_positions_x: torch.Tensor,
+ abs_positions_y: torch.Tensor,
+):
+ """应用2D RoPE到输入张量x。"""
+ dim = x.shape[-1]
+ dim_half = dim // 2
+
+    # Split the embedding: the first half gets RoPE along one axis and the second half along the other,
+    # e.g. the first half for the X coordinate and the second half for Y (or the reverse, as long as it stays consistent).
+ x_part_1 = x[..., :dim_half]
+ x_part_2 = x[..., dim_half:]
+
+    # apply the rotation derived from abs_positions_x to x_part_1
+ rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x)
+    # apply the rotation derived from abs_positions_y to x_part_2
+ rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y)
+
+    # concatenate the halves back together; the order must match the split above
+ return torch.cat((rotated_part_1, rotated_part_2), dim=-1)
+
+
+def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
+ """
+ Compute patch coordinates (x, y)
+
+ Args:
+ grid_hw: (B, 2) tensor representing (H, W) per image
+ """
+ device = grid_hw.device
+ B = grid_hw.shape[0]
+
+ # Get the number of patches per image
+ H = grid_hw[:, 0]
+ W = grid_hw[:, 1]
+ N = H * W
+ N_total = N.sum()
+
+ # Create the batch index for each patch (B x patch count)
+ patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,)
+
+ # Generate intra-image patch index (row-major order)
+ patch_id_within_image = torch.arange(N_total, device=device)
+ patch_id_within_image = (
+ patch_id_within_image
+ - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample]
+ )
+
+ # Get H/W for each patch according to its image
+ W_per_patch = W[patch_to_sample]
+ abs_x = patch_id_within_image % W_per_patch
+ abs_y = patch_id_within_image // W_per_patch
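+    # e.g. grid_hw = [[2, 3]] gives abs_x = [0, 1, 2, 0, 1, 2] and abs_y = [0, 0, 0, 1, 1, 1]:
+    # row-major (x, y) coordinates for each patch of a 2x3 grid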
+
+ return abs_x, abs_y
+
+
+class NeoVisionTransformerPretrainedModel(nn.Module):
+ def __init__(
+ self,
+ kvargs,
+ hidden_size: int = 1024,
+ llm_hidden_size: int = 2048,
+ downsample_ratio: float = 0.5,
+ patch_size: int = 16,
+ num_channels: int = 3,
+ max_position_embeddings_vision: int = 10000,
+ rope_theta_vision: float = 10000.0,
+ min_pixels: int = 65536,
+ max_pixels: int = 2408448,
+ **kwargs,
+ ):
+ super().__init__()
+ self.weight_dir = kvargs["weight_dir"]
+ self.data_type = kvargs.get("data_type", "bfloat16")
+ self.embed_dim = hidden_size
+ self.llm_hidden_size = llm_hidden_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.downsample_ratio = downsample_ratio
+ self.downsample_factor = int(1 / downsample_ratio)
+ self.max_position_embeddings_vision = max_position_embeddings_vision
+ self.rope_theta_vision = rope_theta_vision
+ self.rope_dim_part = self.embed_dim // 2
+ self.min_pixels = min_pixels
+ self.max_pixels = max_pixels
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+
+ self.dense_embedding = nn.Conv2d(
+ in_channels=self.embed_dim,
+ out_channels=self.llm_hidden_size,
+ kernel_size=self.downsample_factor,
+ stride=self.downsample_factor,
+ )
+ self.gelu = nn.GELU()
+
+ self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos()
+ self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos()
+ self._init_datatype()
+
+ def _init_datatype(self):
+ if isinstance(self.data_type, torch.dtype):
+ return
+ if self.data_type in ["fp16", "float16"]:
+ self.data_type = torch.float16
+ elif self.data_type in ["bf16", "bfloat16"]:
+ self.data_type = torch.bfloat16
+ elif self.data_type in ["fp32", "float32"]:
+ self.data_type = torch.float32
+ else:
+ raise ValueError(f"Unsupport datatype {self.data_type}!")
+ return
+
+ def load_model(self, weight_dir):
+ bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+ if bin_weight_files:
+ weight_dict = {}
+ for file_ in bin_weight_files:
+ f = torch.load(os.path.join(weight_dir, file_), "cpu")
+ for k, v in f.items():
+ if "vision_model" in k:
+ weight_dict[k[len("vision_model.embeddings.") :]] = v
+ else:
+ hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+ weight_dict = {}
+ for file_ in hf_weight_files:
+ f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+ for k in f.keys():
+ if "vision_model" in k:
+ weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k)
+ self.load_state_dict(weight_dict)
+
+ def precompute_rope_freqs_sincos(self):
+ inv_freq = 1.0 / (
+ self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part)
+ )
+ t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq)
+ freqs = torch.outer(t, inv_freq)
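+        # with the defaults (embed_dim=1024 -> rope_dim_part=512, max_position_embeddings_vision=10000)
+        # the returned tables have shape (10000, 256): one angle per position per frequency pair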
+ return torch.cos(freqs), torch.sin(freqs)
+
+ def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw):
+ """
+ Apply 2D Rotary Position Embedding to the patch embeddings.
+ """
+ abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device)
+ embeddings = apply_2d_rotary_pos_emb(
+ patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32
+ self.cos_x,
+ self.sin_x,
+ self.cos_y,
+ self.sin_y,
+ abs_pos_x,
+ abs_pos_y,
+ ).to(self.patch_embedding.weight.dtype)
+ return embeddings
+
+ def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
+ pixel_values = pixel_values.view(
+ -1,
+ 3,
+ self.patch_size,
+ self.patch_size,
+ )
+ patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim)
+ patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw)
+ assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[
+ 0
+ ], "Grid size and patch embeds size mismatch."
+
+ patches_list = []
+ cur_position = 0
+ for i in range(grid_hw.shape[0]):
+ h, w = grid_hw[i]
+ patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0)
+ patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2))
+ patches_per_img = patches_per_img.permute(0, 2, 3, 1)
+ patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1]))
+ cur_position += h * w
+
+ embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C)
+ assert cur_position == patch_embeds.shape[0]
+ assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2)
+
+ return embeddings
+
+ def encode(self, images: List[ImageItem]):
+ img_tensors = []
+ valid_ids = []
+ valid_id = 0
+ img_grids = []
+ uuids = []
+
+ for i, img in enumerate(images):
+ if isinstance(img, ImageItem):
+ uuids.append(img.uuid)
+ image_data = read_shm(get_shm_name_data(img.uuid))
+ image_data = Image.open(BytesIO(image_data))
+ # a = img.extra_params["min_pixels"]
+ # b = img.extra_params["max_pixels"]
+ # print(f"self.min_pixels is {a} ,max_pixelx is {b}")
+ pixel_values, image_grid_hw = load_image_native(
+ image_data,
+ patch_size=self.patch_size,
+ downsample_ratio=self.downsample_ratio,
+ min_pixels=img.extra_params["min_pixels"],
+ max_pixels=img.extra_params["max_pixels"],
+ )
+ img_tensors.append(pixel_values)
+ img_grids.append(image_grid_hw)
+ else:
+ raise Exception("Unsupport input types: {} for {}".format(type(img), img))
+
+            # the patch count must be divisible by merge_length
+ cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2))
+ valid_ids.append([valid_id, valid_id + cur_num])
+ valid_id += cur_num
+
+ if len(img_tensors) <= 0:
+ return None
+
+ imgs = torch.cat(img_tensors, dim=0)
+ grid_hw = torch.cat(img_grids, dim=0)
+
+ pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+ image_grid_hw = grid_hw.to("cuda", non_blocking=True)
+
+ all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw)
+
+ return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/__init__.py b/lightllm/models/neo_chat_moe/triton_kernel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
new file mode 100644
index 0000000000..42c3254e27
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
@@ -0,0 +1,431 @@
+import math
+import torch
+import triton
+import triton.language as tl
+
+from lightllm.utils.device_utils import is_tesla
+
+
+@triton.jit
+def _fwd_kernel(
+ Q,
+ K,
+ V,
+ sm_scale,
+ Out,
+ position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0]
+ B_Start_Loc,
+ B_Seqlen,
+ Req_to_tokens,
+ B_req_idx,
+ stride_qbs,
+ stride_qh,
+ stride_qd,
+ stride_kbs,
+ stride_kh,
+ stride_kd,
+ stride_vbs,
+ stride_vh,
+ stride_vd,
+ stride_obs,
+ stride_oh,
+ stride_od,
+ stride_req_to_tokens_b,
+ stride_req_to_tokens_s,
+ kv_group_num,
+ b_prompt_cache_len,
+ b_image_token_tag,
+ H: tl.constexpr,
+ QK_HEAD_DIM: tl.constexpr,
+ V_HEAD_DIM: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ start_m = tl.program_id(0)
+ cur_bh = tl.program_id(1)
+ cur_batch = cur_bh // H
+ cur_head = cur_bh % H
+
+ cur_kv_head = cur_head // kv_group_num
+
+ cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+ prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)
+ total_len = tl.load(B_Seqlen + cur_batch)
+ cur_batch_seq_len = total_len - prompt_cache_len # NEW len
+ cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+ block_start_loc = BLOCK_M * start_m
+ if block_start_loc >= cur_batch_seq_len:
+ return
+
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d_qk = tl.arange(0, QK_HEAD_DIM)
+ offs_d_v = tl.arange(0, V_HEAD_DIM)
+ offs_m = block_start_loc + tl.arange(0, BLOCK_M)
+
+ # Q pointers
+ off_q = (
+ (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ + cur_head * stride_qh
+ + offs_d_qk[None, :] * stride_qd
+ )
+
+ q_valid = offs_m < cur_batch_seq_len
+ q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0)
+
+ # online softmax state
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
+ block_end_loc = total_len
+
+ # absolute q positions in the request
+ q_pos = prompt_cache_len + offs_m # [M]
+ q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False)
+
+ for start_n in range(0, block_end_loc, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+
+ k_pos = start_n + offs_n # [N]
+ k_valid = k_pos < block_end_loc
+
+ # map logical pos -> mem_index (for K/V)
+ kv_loc = tl.load(
+ Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos,
+ mask=k_valid,
+ other=0,
+ ).to(tl.int64)
+
+ # load K
+ off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d_qk[:, None] * stride_kd
+ k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0)
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, k)
+
+ # mask: causal OR same gid (only possible inside NEW part)
+ mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None]
+ qk = tl.where(mask, qk * sm_scale, -1.0e8)
+
+ # online softmax
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
+ qk -= m_ij[:, None]
+ p = tl.math.exp2(qk)
+ l_ij = tl.sum(p, 1)
+
+ alpha = tl.math.exp2(m_i - m_ij)
+ l_i = l_i * alpha + l_ij
+ acc = acc * alpha[:, None]
+
+ # load V
+ off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d_v[None, :] * stride_vd
+ v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0)
+
+ p = p.to(v.dtype)
+ acc = tl.dot(p, v, acc)
+
+ m_i = m_ij
+
+ acc = acc / l_i[:, None]
+
+ off_o = (
+ (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ + cur_head * stride_oh
+ + offs_d_v[None, :] * stride_od
+ )
+ tl.store(Out + off_o, acc, mask=q_valid[:, None])
+
+
+@torch.no_grad()
+def context_attention_fwd_neo(
+ q,
+ k,
+ v,
+ o,
+ position_ids, # 1D packed like q (only NEW tokens)
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_input_len,
+ req_to_token_indexs,
+ b_image_token_tag,
+):
+ # minimal safety: position_ids must cover packed q rows
+ assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0])
+
+ BLOCK_M = 128 if not is_tesla() else 64
+
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+ assert Lq == Lk and Lk == Lv
+ assert Lk in {16, 32, 64, 128, 256}
+ base_head_dim = Lq // 2
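+    # q/k pack the base head dim plus the 2D-RoPE halves, so scale by the base head dim only;
+    # the 1.4426950408889634 factor is log2(e), letting the kernel use exp2 in the online softmax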
+ sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634
+
+ batch, head = b_seq_len.shape[0], q.shape[1]
+ kv_group_num = q.shape[1] // k.shape[1]
+
+ grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1)
+
+ BLOCK_N = BLOCK_M
+ num_warps = 4 if Lk <= 64 else 8
+ num_stages = 1
+
+ _fwd_kernel[grid](
+ q,
+ k,
+ v,
+ sm_scale,
+ o,
+ position_ids,
+ b_start_loc,
+ b_seq_len,
+ req_to_token_indexs,
+ b_req_idx,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ o.stride(0),
+ o.stride(1),
+ o.stride(2),
+ req_to_token_indexs.stride(0),
+ req_to_token_indexs.stride(1),
+ kv_group_num=kv_group_num,
+ b_prompt_cache_len=b_prompt_cache_len,
+ b_image_token_tag=b_image_token_tag,
+ H=head,
+ QK_HEAD_DIM=Lk,
+ V_HEAD_DIM=Lk // 2,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+
+
+def reference_attention(
+ q,
+ k,
+ v,
+ position_ids_q, # 1D packed like q (only NEW tokens)
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ req_to_token_indexs,
+):
+ device = q.device
+ dtype = q.dtype
+ sum_q, Hq, D = q.shape
+ Hk = k.shape[1]
+ kv_group_num = Hq // Hk
+
+ batch = b_seq_len.shape[0]
+ out = torch.empty_like(q)
+ scale = 1.0 / math.sqrt(D)
+
+ for b in range(batch):
+ req = int(b_req_idx[b].item())
+ total_len = int(b_seq_len[b].item())
+ prompt_len = int(b_prompt_cache_len[b].item())
+ new_len = total_len - prompt_len
+
+ q_start = int(b_start_loc[b].item())
+ q_blk = q[q_start : q_start + new_len] # [M, Hq, D]
+ gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M]
+
+ # gather K/V for full request by logical pos -> mem_index
+ token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L]
+ k_blk = k[token_locs] # [L, Hk, D]
+ v_blk = v[token_locs] # [L, Hk, D]
+
+ # expand kv heads to q heads (GQA)
+ k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D]
+ v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D]
+
+ # positions
+ q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M]
+ k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L]
+
+ # build allow mask:
+ # causal always
+ allow = k_pos[None, :] <= q_pos[:, None]
+
+ # full-attn only inside NEW part by gid
+ # compare only when k_pos in NEW
+ k_in_new = k_pos >= prompt_len
+ k_rel = (k_pos - prompt_len).clamp_min(0) # [L]
+ # map k_rel to gid_new, but only valid where k_in_new
+ k_gid = torch.empty((total_len,), device=device, dtype=torch.int64)
+ k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new
+ k_gid[k_in_new] = gid_new[k_rel[k_in_new]]
+
+ allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :])
+
+ # scores: [Hq, M, L]
+ q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D]
+ k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L]
+ scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L]
+
+ neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32)
+ scores = torch.where(allow[None, :, :], scores, neg)
+
+ p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L]
+ v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D]
+ out_hq = torch.matmul(p, v_t) # [Hq, M, D]
+ out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D]
+
+ out[q_start : q_start + new_len] = out_blk
+
+ return out
+
+
+def make_test_case(
+ device="cuda",
+ dtype=torch.float16,
+ batch=3,
+ Hq=8,
+ Hk=4,
+ D=64,
+ seed=0,
+ base_index=50000,
+):
+ torch.manual_seed(seed)
+
+ # prompt (cached) len and new len
+ prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device)
+ new_lens = torch.randint(low=1, high=8, size=(batch,), device=device)
+ total_lens = (prompt_lens + new_lens).to(torch.int32)
+
+ max_total_len = int(total_lens.max().item())
+ max_new_len = int(new_lens.max().item())
+
+ # packed q start
+ b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32)
+ cur = 0
+ for b in range(batch):
+ b_start_loc[b] = cur
+ cur += int(new_lens[b].item())
+ sum_q = cur
+
+ b_seq_len = total_lens
+ b_prompt_cache_len = prompt_lens.to(torch.int32)
+
+ # one req per batch
+ num_req = batch
+ b_req_idx = torch.arange(batch, device=device, dtype=torch.int32)
+
+ # global KV space large, indices not small
+ sum_kv = int(total_lens.sum().item())
+ kv_size = base_index + sum_kv + 1024
+ pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index
+
+ # Req_to_tokens [num_req, max_total_len]
+ req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32)
+ p = 0
+ for r in range(num_req):
+ L = int(total_lens[r].item())
+ req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32)
+ p += L
+
+ # position_ids_q: only NEW tokens, packed like q
+ position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32)
+ for b in range(batch):
+ M = int(new_lens[b].item())
+ start = int(b_start_loc[b].item())
+
+ gid = torch.arange(M, device=device, dtype=torch.int32)
+
+ # make one repeated block inside NEW part to simulate image tokens
+ if M >= 4 and torch.rand((), device=device).item() > 0.3:
+ s = int(torch.randint(0, M - 2, (1,), device=device).item())
+ e = min(M, s + 3)
+ gid[s:e] = gid[s]
+
+ position_ids_q[start : start + M] = gid
+
+ q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype)
+ k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
+ v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
+ o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype)
+
+ return (
+ q,
+ k,
+ v,
+ o,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_new_len,
+ req_to_token_indexs,
+ )
+
+
+def check_once(device="cuda", dtype=torch.float16, seed=0):
+ (
+ q,
+ k,
+ v,
+ o,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_new_len,
+ req_to_token_indexs,
+ ) = make_test_case(device=device, dtype=dtype, seed=seed)
+
+ context_attention_fwd_neo(
+ q,
+ k,
+ v,
+ o,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_new_len,
+ req_to_token_indexs,
+ )
+
+ ref = reference_attention(
+ q,
+ k,
+ v,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ req_to_token_indexs,
+ )
+
+ diff = (o - ref).abs()
+ max_abs = diff.max().item()
+ denom = ref.abs().max().item() + 1e-6
+ max_rel = max_abs / denom
+
+ print(f"seed={seed}, dtype={dtype}")
+ print(f"max_abs_error = {max_abs:.6e}")
+ print(f"max_rel_error = {max_rel:.6e}")
+ print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2))
+
+
+if __name__ == "__main__":
+ if not torch.cuda.is_available():
+ print("No CUDA, skip.")
+ else:
+ torch.cuda.synchronize()
+ check_once(dtype=torch.bfloat16, seed=0)
+ check_once(dtype=torch.bfloat16, seed=1)
+ check_once(dtype=torch.bfloat16, seed=2)
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
new file mode 100644
index 0000000000..1a3d4af73b
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
@@ -0,0 +1,191 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _get_neo_position_triton(
+ b_image_start_idx: torch.Tensor,
+ b_image_thwd: torch.Tensor,
+ b_image_thwd_stride0: torch.Tensor,
+ b_image_nums: torch.Tensor,
+ b_image_start_num: torch.Tensor,
+ b_image_len: torch.Tensor,
+ position_ids: torch.Tensor,
+ position_ids_stride0: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_q_seq_len: torch.Tensor,
+ b_start_loc: torch.Tensor,
+ b_image_token_tag: torch.Tensor,
+ BLOCK_SIZE: tl.constexpr,
+) -> torch.Tensor:
+ cur_batch = tl.program_id(0)
+ cache_len = tl.load(b_ready_cache_len + cur_batch)
+ q_seq_len = tl.load(b_q_seq_len + cur_batch)
+ image_num = tl.load(b_image_nums + cur_batch)
+ image_start_num = tl.load(b_image_start_num + cur_batch)
+ start_loc = tl.load(b_start_loc + cur_batch)
+ for i in range(image_num):
+ local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+ image_start_idx = start_loc + local_image_start_idx - cache_len
+ image_len = tl.load(b_image_len + image_start_num + i)
+ # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
+ image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
+ for j in range(0, image_len, BLOCK_SIZE):
+ off = j + tl.arange(0, BLOCK_SIZE)
+            # video is not handled yet, so t is always 0
+ t_pos = local_image_start_idx + off * 0
+ h_pos = off // image_w
+ w_pos = off % image_w
+ tl.store(
+ b_image_token_tag + off + image_start_idx,
+ True,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+ tl.store(
+ position_ids + off + image_start_idx,
+ t_pos,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+ tl.store(
+ position_ids + position_ids_stride0 + off + image_start_idx,
+ h_pos,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+ tl.store(
+ position_ids + position_ids_stride0 * 2 + off + image_start_idx,
+ w_pos,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+
+ for i in range(image_num):
+ local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+ image_len = tl.load(b_image_len + image_start_num + i)
+ image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3)
+ image_end = local_image_start_idx + image_len - cache_len
+ text_start = tl.maximum(0, image_end)
+ for j in range(text_start, q_seq_len, BLOCK_SIZE):
+ off = j + tl.arange(0, BLOCK_SIZE)
+ t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta
+ h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0)
+ w_pos = tl.load(
+ position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0
+ )
+ tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len))
+ tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len))
+ tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len))
+ return
+
+
+def get_neo_position_triton(
+ b_image_start_idx: torch.Tensor,
+ b_image_thwd: torch.Tensor,
+ b_image_nums: torch.Tensor,
+ b_image_start_num: torch.Tensor,
+ b_image_len: torch.Tensor,
+ position_ids: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_q_seq_len: torch.Tensor,
+ b_start_loc: torch.Tensor,
+ b_image_token_tag: torch.Tensor,
+) -> torch.Tensor:
+
+ batch_size = b_q_seq_len.shape[0]
+ assert batch_size == b_image_nums.shape[0]
+ grid = (batch_size,)
+ BLOCK_SIZE = 64
+ _get_neo_position_triton[grid](
+ b_image_start_idx=b_image_start_idx,
+ b_image_thwd=b_image_thwd,
+ b_image_thwd_stride0=b_image_thwd.stride(0),
+ b_image_nums=b_image_nums,
+ b_image_start_num=b_image_start_num,
+ b_image_len=b_image_len,
+ position_ids=position_ids,
+ position_ids_stride0=position_ids.stride(0),
+ b_ready_cache_len=b_ready_cache_len,
+ b_q_seq_len=b_q_seq_len,
+ b_start_loc=b_start_loc,
+ b_image_token_tag=b_image_token_tag,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+
+
+def test():
+ b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda")
+ b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda")
+ b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
+ b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
+ b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda")
+ position_ids = (
+ torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda")
+ .unsqueeze(0)
+ .expand(3, -1)
+ .contiguous()
+ )
+ b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda")
+ position_ids[1:].zero_()
+ b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda")
+ b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda")
+ b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda")
+ get_neo_position_triton(
+ b_image_start_idx,
+ b_image_thwd,
+ b_image_nums,
+ b_image_start_num,
+ b_image_len,
+ position_ids,
+ b_ready_cache_len,
+ b_q_seq_len,
+ b_start_loc,
+ b_image_token_tag,
+ )
+
+ print(b_image_token_tag)
+ print(position_ids)
+ # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)
+
+ # position_ids = (
+ # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda")
+ # .unsqueeze(0)
+ # .expand(3, -1)
+ # .contiguous()
+ # )
+ # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda")
+ # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda")
+ # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda")
+
+ # get_neo_position_triton(
+ # b_image_start_idx,
+ # b_image_thwd,
+ # b_image_nums,
+ # b_image_start_num,
+ # b_image_len,
+ # position_ids,
+ # b_ready_cache_len,
+ # b_q_seq_len,
+ # b_start_loc,
+ # )
+
+ # print(f"old_value:\n{old_value}")
+ # print(f"position_ids:\n{position_ids}")
+ # assert torch.equal(old_value, position_ids)
+
+ """
+ tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8],
+ [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8],
+ [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]],
+ device='cuda:0', dtype=torch.int32)
+ """
+
+
+if __name__ == "__main__":
+ test()
diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py
new file mode 100644
index 0000000000..fbd57a5e9c
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/vision_process.py
@@ -0,0 +1,141 @@
+import re
+import math
+import torch
+import string
+import numpy as np
+import pandas as pd
+from PIL import Image
+import torch.distributed as dist
+import torchvision.transforms as T
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def round_by_factor(number: int, factor: int) -> int:
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
+ return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+ return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+ return math.floor(number / factor) * factor
+
+
+# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60
+def smart_resize(
+ height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304
+) -> tuple[int, int]:
+ """
+ Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
+ """
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = max(factor, round_by_factor(height, factor))
+ w_bar = max(factor, round_by_factor(width, factor))
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = max(factor, floor_by_factor(height / beta, factor))
+ w_bar = max(factor, floor_by_factor(width / beta, factor))
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = ceil_by_factor(height * beta, factor)
+ w_bar = ceil_by_factor(width * beta, factor)
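+    # e.g. smart_resize(1000, 800) with factor=32 rounds both sides to multiples of 32, giving
+    # (992, 800); 992 * 800 = 793600 already lies inside [min_pixels, max_pixels], so no rescale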
+ return h_bar, w_bar
+
+
+def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs):
+ width, height = image.size
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=size_factor,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+ image = image.resize((resized_width, resized_height))
+
+ return image
+
+
+def preprocess_pixel_values(pixel_values, patch_size=16):
+ c, h, w = pixel_values.shape
+ grid_h = h // patch_size
+ grid_w = w // patch_size
+
+ flatten_pixel_values = (
+ pixel_values.view(c, grid_h, patch_size, grid_w, patch_size)
+ .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size]
+ .reshape(grid_h * grid_w, c * patch_size ** 2)
+ )
+
+ grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device)
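+    # e.g. a (3, 64, 96) image with patch_size=16 yields a (24, 768) tensor of flattened
+    # 16x16x3 patches and grid_hw = [[4, 6]]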
+
+ return flatten_pixel_values, grid_hw
+
+
+def get_contrasting_background(image):
+ """
+ Calculate the color (white or black) that is different from the average foreground color
+ to use as the background color
+ """
+ image_np = np.array(image)
+ if (image_np[:, :, 3] == 0).any():
+ non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0]
+ if non_transparent_pixels.size == 0:
+ return None
+ pixel_mean = non_transparent_pixels.mean()
+ contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255)
+ return contrasting_color
+ else:
+ return None
+
+
+def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False):
+ """
+ Load and preprocess an image file, converting it to RGB mode,
+ resizing, normalizing, and optionally adding a thumbnail version.
+ """
+ if image.mode == "RGBA":
+ bg_color = get_contrasting_background(image)
+ if bg_color:
+ background = Image.new("RGB", image.size, bg_color)
+ background.paste(image, mask=image.split()[3])
+ image = background.convert("RGB")
+ else:
+ image = image.convert("RGB")
+ else:
+ image = image.convert("RGB")
+
+ if upscale:
+ image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR)
+
+ transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.ToTensor(),
+ T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+ ]
+ )
+
+ new_image = dynamic_preprocess_native_resolution(
+ image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size)
+
+ # print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})")
+
+ return pixel_values, grid_hw
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 9eccddffc1..af035e81b6 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -93,8 +93,10 @@ def _tpsp_get_qkv(
input = gather_input[0 : len(infer_state.input_ids), :]
input = input.view(-1, self.embed_dim_)
- q = layer_weight.q_proj.mm(input)
- cache_kv = layer_weight.kv_proj.mm(input)
+ qkv = layer_weight.qkv_proj.mm(input)
+ q, cache_kv = qkv.split(
+ [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1
+ )
layer_weight.q_norm_weight_(q, eps=self.eps_)
layer_weight.k_norm_weight_(
cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
@@ -130,6 +132,7 @@ def _moe_ffn(
use_grouped_topk=False,
topk_group=None,
num_expert_group=None,
+ microbatch_index=infer_state.microbatch_index,
)
return hidden_states.view(num_tokens, hidden_dim)
@@ -150,6 +153,7 @@ def _moe_ffn_edp(
topk_group=None,
num_expert_group=None,
is_prefill=infer_state.is_prefill,
+ microbatch_index=infer_state.microbatch_index,
)
ep_output = ep_output.view(token_num, hidden_dim)
diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
index 13ba6cbe0f..5a857fd093 100644
--- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
@@ -52,6 +52,11 @@ def _init_moe(self):
tp_rank=0,
tp_world_size=1,
)
+ mlp_only = set(self.network_config_.get("mlp_only_layers", []))
+ step = self.network_config_.get("decoder_sparse_step", 1)
+ moe_layer_index = sum(
+ 1 for i in range(self.layer_num_) if self.n_routed_experts > 0 and i not in mlp_only and (i + 1) % step == 0
+ )
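+        # moe_layer_index counts how many MoE layers precede this one (mlp_only and
+        # non-sparse-step layers excluded)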
self.experts = FusedMoeWeight(
gate_proj_name="gate_proj",
down_proj_name="down_proj",
@@ -65,6 +70,7 @@ def _init_moe(self):
quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
layer_num=self.layer_num_,
network_config=self.network_config_,
+ moe_layer_index=moe_layer_index,
)
def _init_qkv(self):
diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py
index 10a5051276..2926a12b1f 100644
--- a/lightllm/models/qwen3_moe/model.py
+++ b/lightllm/models/qwen3_moe/model.py
@@ -4,6 +4,7 @@
from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.log_utils import init_logger
from lightllm.distributed.communication_op import dist_group_manager
@@ -26,3 +27,6 @@ def __init__(self, kvargs):
def _init_custom(self):
super()._init_custom()
dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
+ if self.args.enable_return_routed_experts:
+ num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe)
+ init_routing_capture(self, num_moe_layers)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 73b9bad4a4..409460feb0 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -1,8 +1,7 @@
import argparse
-def make_argument_parser() -> argparse.ArgumentParser:
- parser = argparse.ArgumentParser()
+def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
@@ -607,6 +606,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used."""
)
+ parser.add_argument(
+ "--enable_torch_memory_saver",
+ action="store_true",
+ help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""",
+ )
+ parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""")
parser.add_argument(
"--disk_cache_dir",
type=str,
@@ -639,4 +644,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
If the op is not implemented for the platform and the hardware support triton,
it will use triton implementation.""",
)
+ parser.add_argument(
+ "--enable_return_routed_experts",
+ action="store_true",
+ default=False,
+ help="Enable returning routed expert indices for MoE models (R3 feature).",
+ )
return parser
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index be2315d34c..2502df9777 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -33,7 +33,7 @@
import uuid
from PIL import Image
import multiprocessing as mp
-from typing import AsyncGenerator, Union
+from typing import Any, AsyncGenerator, Union
from typing import Callable
from lightllm.server import TokenLoad
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
@@ -49,6 +49,7 @@
from lightllm.utils.error_utils import ServerBusyError
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.io_struct import ReleaseMemoryReq, ResumeMemoryReq
from dataclasses import dataclass
from .api_openai import chat_completions_impl, completions_impl
@@ -58,6 +59,15 @@
CompletionRequest,
CompletionResponse,
)
+from .io_struct import (
+ AbortReq,
+ FlushCacheReq,
+ InitWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupReq,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromTensorReq,
+ GeneralModelToHttpRpcRsp,
+)
from .build_prompt import build_prompt, init_tokenizer
logger = init_logger(__name__)
@@ -132,6 +142,22 @@ def get_model_name():
return {"model_name": g_objs.args.model_name}
+@app.get("/get_server_info")
+@app.post("/get_server_info")
+def get_server_info():
+    # convert the StartArgs dataclass into a plain dict
+ from dataclasses import asdict
+
+ server_info: dict[str, Any] = asdict(g_objs.args)
+ return {**server_info}
+
+
+@app.get("/get_weight_version")
+@app.post("/get_weight_version")
+def get_weight_version():
+ return {"weight_version": g_objs.args.weight_version}
+
+
@app.get("/healthz", summary="Check server health")
@app.get("/health", summary="Check server health")
@app.head("/health", summary="Check server health")
@@ -293,6 +319,83 @@ async def metrics() -> Response:
return response
+@app.post("/abort_request")
+async def abort_request(request: AbortReq, raw_request: Request):
+ """Abort a request."""
+ try:
+ await g_objs.httpserver_manager.abort_request(request)
+ return Response(status_code=200)
+ except Exception as e:
+ return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
+
+
+async def handle_request_common(request_obj, handler):
+ try:
+ ret: GeneralModelToHttpRpcRsp = await handler(request_obj)
+ if ret.success:
+ return JSONResponse({"success": ret.success, "message": ret.msg}, status_code=200)
+ else:
+ return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg)
+ except Exception as e:
+ logger.error("handle_request_common (%s) error occurred: %s", str(request_obj), str(e), exc_info=True)
+ return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
+
+
+@app.post("/init_weights_update_group")
+async def init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request):
+ """Init weights update group."""
+ return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group)
+
+
+@app.post("/destroy_weights_update_group")
+async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request):
+ """Destroy weights update group."""
+ return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request):
+ """Update model parameter from distributed online."""
+ return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed)
+
+
+@app.post("/update_weights_from_tensor")
+async def update_weights_from_tensor(request: UpdateWeightsFromTensorReq, raw_request: Request):
+ """Update model parameter from distributed online."""
+ return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor)
+
+
+@app.post("/flush_cache")
+@app.get("/flush_cache")
+async def flush_cache():
+ """Flush the radix cache."""
+ return await handle_request_common(FlushCacheReq(), g_objs.httpserver_manager.flush_cache)
+
+
+@app.post("/pause_generation")
+async def pause_generation():
+ await g_objs.httpserver_manager.pause_generation()
+ return Response(content="Generation paused successfully.", status_code=200)
+
+
+@app.post("/continue_generation")
+async def continue_generation():
+ await g_objs.httpserver_manager.continue_generation()
+ return Response(content="Generation continued successfully.", status_code=200)
+
+
+@app.get("/release_memory_occupation")
+@app.post("/release_memory_occupation")
+async def release_memory_occupation(request: ReleaseMemoryReq):
+ return await handle_request_common(request, g_objs.httpserver_manager.release_memory_occupation)
+
+
+@app.get("/resume_memory_occupation")
+@app.post("/resume_memory_occupation")
+async def resume_memory_occupation(request: ResumeMemoryReq):
+ return await handle_request_common(request, g_objs.httpserver_manager.resume_memory_occupation)
+
+
@app.websocket("/pd_register")
async def register_and_keep_alive(websocket: WebSocket):
await websocket.accept()
diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py
index d3592a5f54..d15bec6485 100644
--- a/lightllm/server/api_lightllm.py
+++ b/lightllm/server/api_lightllm.py
@@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
prompt = request_dict.pop("inputs")
sample_params_dict = request_dict["parameters"]
return_details = sample_params_dict.pop("return_details", False)
+ return_routed_experts = sample_params_dict.pop(
+ "return_routed_experts", httpserver_manager.args.enable_return_routed_experts
+ )
sampling_params = SamplingParams()
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
sampling_params.verify()
@@ -53,6 +56,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
prompt_token_ids = None
is_first_metadata = True
input_usage = None
+ routed_experts_data = None
async for sub_req_id, request_output, metadata, finish_status in results_generator:
# when set "--return_all_prompt_logprobs", the first token metadata will contains
# prompt_logprobs and prompt_token_ids
@@ -78,6 +82,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
if finish_status.is_finished():
finish_reason_dict[sub_req_id] = finish_status
+ if "routed_experts" in metadata:
+ routed_experts_data = metadata["routed_experts"]
n = sampling_params.n
sub_ids = list(final_output_dict.keys())[:n]
final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids]
@@ -102,6 +108,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
ret["prompt_logprobs"] = prompt_logprobs
if input_usage is not None:
ret["input_usage"] = input_usage
+ if return_routed_experts and routed_experts_data is not None:
+ ret["routed_experts"] = routed_experts_data
return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8"))
@@ -112,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer
prompt = request_dict.pop("inputs")
sample_params_dict = request_dict["parameters"]
_ = sample_params_dict.pop("return_details", False)
+ _ = sample_params_dict.pop("return_routed_experts", None)
sampling_params = SamplingParams()
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
sampling_params.verify()
diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index b4447d808a..1eb5ff24c0 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -1,15 +1,36 @@
import torch
-from .api_cli import make_argument_parser
+from .api_cli import add_cli_args
+from lightllm.server.core.objs.start_args_type import StartArgs
+from lightllm.utils.log_utils import init_logger
-if __name__ == "__main__":
- torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess
- parser = make_argument_parser()
- args = parser.parse_args()
+logger = init_logger(__name__)
+
+
+def launch_server(args: StartArgs):
from .api_start import pd_master_start, normal_or_p_d_start, config_server_start
+ try:
+        # spawn is required here; fork-based subprocess settings will not work
+ torch.multiprocessing.set_start_method("spawn")
+ except RuntimeError as e:
+ logger.warning(f"Failed to set start method: {e}")
+ except Exception as e:
+ logger.error(f"Failed to set start method: {e}")
+ raise e
+
if args.run_mode == "pd_master":
pd_master_start(args)
elif args.run_mode == "config_server":
config_server_start(args)
else:
normal_or_p_d_start(args)
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser()
+ add_cli_args(parser)
+ args = parser.parse_args()
+
+ launch_server(StartArgs(**vars(args)))
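With this refactor the server can also be launched programmatically. A minimal sketch, assuming the StartArgs defaults (see start_args_type.py below) are acceptable apart from model_dir:

    from lightllm.server.api_server import launch_server
    from lightllm.server.core.objs.start_args_type import StartArgs

    args = StartArgs(model_dir="/path/to/model", host="0.0.0.0", port=8000)
    launch_server(args)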
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index bd8e4db8b2..58dac941b0 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -1,3 +1,4 @@
+import multiprocessing as mp
import os
import sys
import time
@@ -16,6 +17,7 @@
from lightllm.utils.process_check import is_process_active
from lightllm.utils.multinode_utils import send_and_receive_node_ip
from lightllm.utils.shm_size_check import check_recommended_shm_size
+from lightllm.server.core.objs.start_args_type import StartArgs
logger = init_logger(__name__)
@@ -51,20 +53,44 @@ def signal_handler(sig, frame):
process_manager.terminate_all_processes()
logger.info("All processes have been terminated gracefully.")
sys.exit(0)
+ elif sig == signal.SIGHUP:
+ logger.info("Received SIGHUP (terminal closed), shutting down gracefully...")
+ if http_server_process and http_server_process.poll() is None:
+ http_server_process.send_signal(signal.SIGTERM)
+
+ start_time = time.time()
+ while (time.time() - start_time) < 60:
+ if not is_process_active(http_server_process.pid):
+ logger.info("httpserver exit")
+ break
+ time.sleep(1)
+
+ if time.time() - start_time < 60:
+ logger.info("HTTP server has exited gracefully")
+ else:
+ logger.warning("HTTP server did not exit in time, killing it...")
+ kill_recursive(http_server_process)
+
+ process_manager.terminate_all_processes()
+ logger.info("All processes have been terminated gracefully due to terminal closure.")
+ sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGHUP, signal_handler)
logger.info(f"start process pid {os.getpid()}")
logger.info(f"http server pid {http_server_process.pid}")
return
-def normal_or_p_d_start(args):
- from lightllm.server.core.objs.start_args_type import StartArgs
+def _set_envs_and_config(args: StartArgs):
+ mp.set_start_method("spawn", force=True)
+
- args: StartArgs = args
+def _launch_subprocesses(args: StartArgs):
+ _set_envs_and_config(args)
set_unique_server_name(args)
if not args.disable_shm_warning:
@@ -138,6 +164,10 @@ def normal_or_p_d_start(args):
assert args.mtp_draft_model_dir is None
assert args.mtp_step == 0
+ # automatically set visual_dp based on visual_tp and tp
+ if args.visual_tp < args.tp and args.tp % args.visual_tp == 0:
+ args.visual_dp = args.tp // args.visual_tp
+
# 检查GPU数量是否足够
if args.visual_gpu_ids is None:
args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
@@ -238,6 +268,7 @@ def normal_or_p_d_start(args):
visual_nccl_ports.append(can_use_ports[0])
can_use_ports = can_use_ports[1:]
+ args.visual_nccl_ports = visual_nccl_ports
# 将申请好的端口放入args参数中
if args.nccl_port is None:
args.nccl_port = nccl_port
@@ -333,6 +364,13 @@ def normal_or_p_d_start(args):
],
)
+ return process_manager
+
+
+def normal_or_p_d_start(args: StartArgs):
+
+ process_manager = _launch_subprocesses(args)
+
# 启动 gunicorn
command = [
"gunicorn",
@@ -372,7 +410,7 @@ def normal_or_p_d_start(args):
return
-def pd_master_start(args):
+def pd_master_start(args: StartArgs):
set_unique_server_name(args)
if args.run_mode != "pd_master":
return
@@ -435,7 +473,7 @@ def pd_master_start(args):
http_server_process.wait()
-def config_server_start(args):
+def config_server_start(args: StartArgs):
set_unique_server_name(args)
if args.run_mode != "config_server":
return
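A worked example of the visual_dp auto-derivation added above (values are illustrative):

    tp, visual_tp = 8, 2
    if visual_tp < tp and tp % visual_tp == 0:
        visual_dp = tp // visual_tp  # -> 4 vision DP groups, visual_dp * visual_tp = 8 GPUs for the ViT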
diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py
index 945e67681d..fdea28f576 100644
--- a/lightllm/server/audioserver/manager.py
+++ b/lightllm/server/audioserver/manager.py
@@ -11,7 +11,7 @@
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from lightllm.utils.log_utils import init_logger
-from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
+from lightllm.server.io_struct import BaseReq, GenerateReqIndex
from lightllm.server.core.objs.shm_req_manager import ShmReqManager, StartArgs
from lightllm.server.multimodal_params import AudioItem
from .model_infer.model_rpc import start_model_process, AudioModelRpcClient
@@ -42,7 +42,7 @@ def __init__(
self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True})
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.cache_port = args.cache_port
- self.waiting_reqs: List[GroupReqIndexes] = []
+ self.waiting_reqs: List[GenerateReqIndex] = []
self.model_weightdir = args.model_dir
self.tp_world_size = args.tp
self.world_size = 1
@@ -136,8 +136,8 @@ async def loop_for_fwd(self):
async def loop_for_netio_req(self):
while True:
- recv_req: GroupReqIndexes = await self.zmq_recv_socket.recv_pyobj()
- if isinstance(recv_req, GroupReqIndexes):
+ recv_req: BaseReq = await self.zmq_recv_socket.recv_pyobj()
+ if isinstance(recv_req, GenerateReqIndex):
self.waiting_reqs.append(recv_req)
else:
assert False, f"Error Req Inf {recv_req}"
diff --git a/lightllm/server/core/objs/io_objs/__init__.py b/lightllm/server/core/objs/io_objs/__init__.py
index c9b806c47d..10386b70e6 100644
--- a/lightllm/server/core/objs/io_objs/__init__.py
+++ b/lightllm/server/core/objs/io_objs/__init__.py
@@ -1 +1 @@
-from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd
+from .group_req import AbortedReqCmd, StopStrMatchedReqCmd
diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py
index dfcbdd2562..d644c0c316 100644
--- a/lightllm/server/core/objs/io_objs/group_req.py
+++ b/lightllm/server/core/objs/io_objs/group_req.py
@@ -1,33 +1,10 @@
from dataclasses import dataclass
from lightllm.server.multimodal_params import MultimodalParams
+from lightllm.server.core.objs.sampling_params import SamplingParams
from typing import List
from ..req import Req
-@dataclass
-class GroupReqIndexes:
- group_req_id: int
- multimodal_params: MultimodalParams
- shm_req_indexes: List[int]
- time_mark: float
-
-
-@dataclass
-class GroupReqObjs:
- group_req_id: int
- multimodal_params: MultimodalParams
- shm_req_objs: List[Req]
- time_mark: float
-
- def to_group_req_index(self):
- return GroupReqIndexes(
- group_req_id=self.group_req_id,
- multimodal_params=self.multimodal_params,
- shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs],
- time_mark=self.time_mark,
- )
-
-
@dataclass
class AbortedReqCmd:
req_id: int
diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py
index 887f360c84..08921317e8 100644
--- a/lightllm/server/core/objs/py_sampling_params.py
+++ b/lightllm/server/core/objs/py_sampling_params.py
@@ -54,6 +54,8 @@ def __init__(
# processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
allowed_token_ids: Optional[List[int]] = None,
+ # if provided, the invalid token ids will be ignored during generation
+ invalid_token_ids: Optional[List[int]] = None,
# p d mode used params
group_request_id: Optional[int] = None,
        # move kv to decode node, only used in pd mode
@@ -88,6 +90,7 @@ def __init__(
self.guided_grammar = guided_grammar
self.guided_json = guided_json
self.allowed_token_ids = allowed_token_ids
+ self.invalid_token_ids = invalid_token_ids
self.group_request_id = group_request_id
self.move_kv_to_decode_node = move_kv_to_decode_node
self.suggested_dp_index = suggested_dp_index
@@ -109,13 +112,18 @@ def __init__(
def load_generation_cfg(cls, weight_dir):
try:
generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
- cls._do_sample = generation_cfg.get("do_sample", False)
- cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0)
- cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0)
- cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0)
- cls._temperature = generation_cfg.get("temperature", 1.0)
- cls._top_p = generation_cfg.get("top_p", 1.0)
- cls._top_k = generation_cfg.get("top_k", -1)
+
+ def _cfg(key, default):
+ v = generation_cfg.get(key)
+ return v if v is not None else default
+
+ cls._do_sample = _cfg("do_sample", False)
+ cls._presence_penalty = _cfg("presence_penalty", 0.0)
+ cls._frequency_penalty = _cfg("frequency_penalty", 0.0)
+ cls._repetition_penalty = _cfg("repetition_penalty", 1.0)
+ cls._temperature = _cfg("temperature", 1.0)
+ cls._top_p = _cfg("top_p", 1.0)
+ cls._top_k = _cfg("top_k", -1)
cls._stop_sequences = generation_cfg.get("stop", None)
except:
pass
@@ -267,6 +275,7 @@ def to_dict(self):
ret["guided_grammar"] = self.guided_grammar
ret["guided_json"] = self.guided_json
ret["allowed_token_ids"] = self.allowed_token_ids
+ ret["invalid_token_ids"] = self.invalid_token_ids
ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node
return ret
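The _cfg helper is needed because generation_config.json may carry explicit nulls, and dict.get only falls back when the key is absent. A small illustration:

    generation_cfg = {"temperature": None}       # explicit null in generation_config.json
    generation_cfg.get("temperature", 1.0)       # -> None, which would break sampling
    v = generation_cfg.get("temperature")
    v if v is not None else 1.0                  # -> 1.0, what _cfg("temperature", 1.0) returns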
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index f489aac9c2..5c7e56843b 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -1,6 +1,7 @@
import os
import math
import ctypes
+import base64
import numpy as np
import time
from .sampling_params import SamplingParams
@@ -13,6 +14,7 @@
from lightllm.utils.kv_cache_utils import compute_token_list_hash
from typing import List, Any, Union
from lightllm.utils.log_utils import init_logger
+from lightllm.utils.shm_utils import create_or_link_shm
logger = init_logger(__name__)
@@ -24,19 +26,20 @@ class FinishStatus(ctypes.Structure):
NO_FINISH = 0
FINISHED_STOP = 1
FINISHED_LENGTH = 2
+ FINISHED_ABORTED = 3
def __init__(self, init_state=NO_FINISH):
self.status = init_state
def set_status(self, new_status):
- assert 0 <= new_status <= 2
+ assert 0 <= new_status <= 3
self.status = new_status
def get_status(self):
return self.status
def is_finished(self):
- return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH
+ return self.FINISHED_STOP <= self.status <= self.FINISHED_ABORTED
def is_stopped(self):
return self.status == self.FINISHED_STOP
@@ -49,6 +52,8 @@ def get_finish_reason(self):
return "stop"
elif self.status == self.FINISHED_LENGTH:
return "length"
+ elif self.status == self.FINISHED_ABORTED:
+ return "abort"
return None
@@ -122,6 +127,8 @@ class Req(ctypes.Structure):
("cpu_cache_match_page_indexes", CpuCachePageList),
# 分块hash的块大小
("cpu_cache_token_page_size", ctypes.c_int),
+ # Number of tokens in routing data SHM, written by model worker, read by HTTP server.
+ ("shm_routing_num_tokens", ctypes.c_int),
]
def get_str(self):
@@ -179,6 +186,7 @@ def init(
self._mtp_step = get_env_start_args().mtp_step
self.stop_str_matched = False
self.stop_str_matched_token_index = -1
+ self.shm_routing_num_tokens = 0
self.post_init()
@@ -227,6 +235,69 @@ def link_logprobs_shm_array(self):
self.shm_logprobs.link_shm()
return
+ def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.int8):
+ """Create routing SHM at actual size (on-demand, not pre-allocated).
+
+ Uses smart mode: links if same-sized SHM exists, otherwise creates new.
+ """
+ service_uni_name = get_unique_server_name()
+ name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}"
+ shape = (num_tokens, num_moe_layers, topk)
+ self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype)
+ self.shm_routing_data.create_shm()
+ self.shm_routing_num_tokens = num_tokens
+ return
+
+ def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=np.int8):
+ """Link to routing SHM from the reader side (HTTP server)."""
+ if num_moe_layers == 0:
+ return
+ num_tokens = self.shm_routing_num_tokens
+ if num_tokens <= 0:
+ return
+ service_uni_name = get_unique_server_name()
+ name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}"
+ shape = (num_tokens, num_moe_layers, topk)
+ self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype)
+ self.shm_routing_data.link_shm()
+ return
+
+ def get_routing_data(self):
+ if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None:
+ return None
+ return self.shm_routing_data.arr
+
+ def close_routing_data_shm_array(self):
+ """Close and unlink routing SHM (on-demand, no longer pooled)."""
+ if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None:
+ self.shm_routing_data.close_shm()
+ self.shm_routing_data = None
+ self.shm_routing_num_tokens = 0
+ return
+
+ def get_routing_metadata(self, num_moe_layers: int, topk: int, dtype_id: int = 1):
+ if num_moe_layers == 0 or topk == 0:
+ return None
+ if self.shm_routing_num_tokens <= 0:
+ return None
+ try:
+ from lightllm.common.basemodel.routing_manager import routing_dtype_id_to_np
+
+ np_dtype = routing_dtype_id_to_np(dtype_id)
+ if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None:
+ self.link_routing_data_shm_array(num_moe_layers, topk, np_dtype=np_dtype)
+ routing_data = self.get_routing_data()
+ if routing_data is None:
+ return None
+ return {
+ "shape": list(routing_data.shape),
+ "dtype": str(routing_data.dtype),
+ "data": base64.b64encode(routing_data.tobytes()).decode("ascii"),
+ }
+ except Exception as e:
+ logger.warning(f"Failed to read routing data for req {self.request_id}: {e}")
+ return None
+
def get_prompt_ids(self):
return self.shm_prompt_ids.arr[: self.input_len].tolist()
@@ -247,9 +318,8 @@ def can_release(self):
ref_count_ok = self.ref_count == 1
can_released_mark = self.can_released_mark
- if self.is_aborted and can_released_mark and ref_count_ok:
- return True
-
+ # if self.is_aborted and can_released_mark and ref_count_ok:
+ # return True
ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched
if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
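A sketch of the intended writer/reader flow for the new routing-SHM helpers, assuming req is an initialized Req, dtype_id 1 maps to int8 in routing_manager, and the sizes are illustrative:

    num_moe_layers, topk, num_tokens = 4, 2, 16

    # model-worker side: allocate at the exact size and fill expert ids during decode
    req.create_routing_data_shm_array(num_moe_layers, num_tokens, topk)
    req.shm_routing_data.arr[:] = 0

    # http-server side: link and serialize into a JSON-friendly payload
    meta = req.get_routing_metadata(num_moe_layers, topk, dtype_id=1)
    # meta -> {"shape": [16, 4, 2], "dtype": "int8", "data": "<base64>"}

    # on release (recycle_resource_loop) the SHM is unlinked again
    req.close_routing_data_shm_array()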
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index d955aa6a87..31e2fbefed 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -17,6 +17,7 @@
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
+INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))
class StopSequence(ctypes.Structure):
@@ -205,6 +206,25 @@ def to_list(self):
return list(self.ids[: self.size])
+class InvalidTokenIds(ctypes.Structure):
+ _pack_ = 4
+ _fields_ = [
+ ("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH),
+ ("size", ctypes.c_int),
+ ]
+
+ def initialize(self, ids: List[int]):
+ self.size = len(ids)
+ assert (
+ self.size <= INVALID_TOKEN_IDS_MAX_LENGTH
+ ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}."
+ self.ids[: self.size] = ids[:]
+ return
+
+ def to_list(self):
+ return list(self.ids[: self.size])
+
+
class ExponentialDecayLengthPenalty(ctypes.Structure):
_pack_ = 4
_fields_ = [
@@ -293,6 +313,8 @@ class SamplingParams(ctypes.Structure):
("ignore_eos", ctypes.c_bool),
# the max number of image patches to be used in the internvl model, for the test
("image_max_patch_num", ctypes.c_int),
+ ("min_pixels", ctypes.c_int),
+ ("max_pixels", ctypes.c_int),
("max_new_tokens", ctypes.c_int),
("min_new_tokens", ctypes.c_int),
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
@@ -304,6 +326,8 @@ class SamplingParams(ctypes.Structure):
# processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
("allowed_token_ids", AllowedTokenIds),
+ # if provided, the invalid token ids will be ignored during generation
+ ("invalid_token_ids", InvalidTokenIds),
("stop_sequences", StopSequenceGroups),
("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
("group_request_id", ctypes.c_int64), # p d mode used params
@@ -333,22 +357,28 @@ class SamplingParams(ctypes.Structure):
def init(self, tokenizer, **kwargs):
super().__init__()
- self.best_of = kwargs.get("best_of", 1)
- self.n = kwargs.get("n", self.best_of)
- self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample)
- self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty)
- self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty)
- self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty)
- self.temperature = kwargs.get("temperature", SamplingParams._temperature)
- self.top_p = kwargs.get("top_p", SamplingParams._top_p)
- self.top_k = kwargs.get("top_k", SamplingParams._top_k)
- self.ignore_eos = kwargs.get("ignore_eos", False)
- self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
- self.max_new_tokens = kwargs.get("max_new_tokens", 16)
- self.min_new_tokens = kwargs.get("min_new_tokens", 1)
- self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
- self.group_request_id = kwargs.get("group_request_id", -1)
- self.suggested_dp_index = kwargs.get("suggested_dp_index", -1)
+
+ def _get(key, default):
+ v = kwargs.get(key)
+ return v if v is not None else default
+
+ self.best_of = _get("best_of", 1)
+ self.n = _get("n", self.best_of)
+ self.do_sample = _get("do_sample", SamplingParams._do_sample)
+ self.presence_penalty = _get("presence_penalty", SamplingParams._presence_penalty)
+ self.frequency_penalty = _get("frequency_penalty", SamplingParams._frequency_penalty)
+ self.repetition_penalty = _get("repetition_penalty", SamplingParams._repetition_penalty)
+ self.temperature = _get("temperature", SamplingParams._temperature)
+ self.top_p = _get("top_p", SamplingParams._top_p)
+ self.top_k = _get("top_k", SamplingParams._top_k)
+ self.ignore_eos = _get("ignore_eos", False)
+ self.min_pixels = _get("min_pixels", -1)
+ self.max_pixels = _get("max_pixels", -1)
+ self.max_new_tokens = _get("max_new_tokens", 16)
+ self.min_new_tokens = _get("min_new_tokens", 1)
+ self.input_penalty = _get("input_penalty", DEFAULT_INPUT_PENALTY)
+ self.group_request_id = _get("group_request_id", -1)
+ self.suggested_dp_index = _get("suggested_dp_index", -1)
self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS)
self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False)
@@ -392,6 +422,11 @@ def init(self, tokenizer, **kwargs):
self.allowed_token_ids = AllowedTokenIds()
self.allowed_token_ids.initialize(allowed_token_ids)
+        # Initialize invalid_token_ids from logit_bias keys (these tokens are banned during generation)
+        invalid_token_ids = [int(token_id) for token_id in (kwargs.get("logit_bias") or {})]
+        self.invalid_token_ids = InvalidTokenIds()
+        self.invalid_token_ids.initialize(invalid_token_ids)
+
if self.do_sample is False:
self.temperature = 1.0
self.top_p = 1.0
@@ -408,13 +443,18 @@ def init(self, tokenizer, **kwargs):
def load_generation_cfg(cls, weight_dir):
try:
generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
- cls._do_sample = generation_cfg.get("do_sample", False)
- cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0)
- cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0)
- cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0)
- cls._temperature = generation_cfg.get("temperature", 1.0)
- cls._top_p = generation_cfg.get("top_p", 1.0)
- cls._top_k = generation_cfg.get("top_k", -1)
+
+ def _cfg(key, default):
+ v = generation_cfg.get(key)
+ return v if v is not None else default
+
+ cls._do_sample = _cfg("do_sample", False)
+ cls._presence_penalty = _cfg("presence_penalty", 0.0)
+ cls._frequency_penalty = _cfg("frequency_penalty", 0.0)
+ cls._repetition_penalty = _cfg("repetition_penalty", 1.0)
+ cls._temperature = _cfg("temperature", 1.0)
+ cls._top_p = _cfg("top_p", 1.0)
+ cls._top_k = _cfg("top_k", -1)
except:
pass
@@ -482,6 +522,8 @@ def to_dict(self):
"image_max_patch_num": self.image_max_patch_num,
"max_new_tokens": self.max_new_tokens,
"min_new_tokens": self.min_new_tokens,
+ "min_pixels": self.min_pixels,
+ "max_pixels": self.max_pixels,
"exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(),
"stop_sequences": self.stop_sequences.to_list(),
"best_of": self.best_of,
@@ -490,6 +532,7 @@ def to_dict(self):
"guided_grammar": self.guided_grammar.to_str(),
"guided_json": self.guided_json.to_str(),
"allowed_token_ids": self.allowed_token_ids.to_list(),
+ "invalid_token_ids": self.invalid_token_ids.to_list(),
"group_request_id": self.group_request_id,
"move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(),
"skip_special_tokens": self.skip_special_tokens,
diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py
index c5ad512c6b..1bf20535ad 100644
--- a/lightllm/server/core/objs/shm_array.py
+++ b/lightllm/server/core/objs/shm_array.py
@@ -26,6 +26,19 @@ def link_shm(self):
self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf)
return
+ def link_shm_partial(self):
+ """Link to an existing SHM that may be larger than the needed shape."""
+ self.shm = create_or_link_shm(self.name, -1, force_mode="link")
+ assert self.shm.size >= self.dest_size, f"SHM {self.name} too small: need {self.dest_size}, got {self.shm.size}"
+ self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf)
+
+ def detach_shm(self):
+ """Close handle without unlinking (SHM persists for reuse)."""
+ if self.shm is not None:
+ self.shm.close()
+ self.shm = None
+ self.arr = None
+
def close_shm(self):
if self.shm is not None:
self.shm.close()
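A usage sketch for the two new ShmArray helpers; the segment name is hypothetical and must refer to an existing SHM at least as large as the requested view:

    import numpy as np

    view = ShmArray("demo_routing_shm", (16, 4, 2), dtype=np.int8)
    view.link_shm_partial()  # maps a prefix of the (possibly larger) existing segment
    data = view.arr.copy()
    view.detach_shm()        # close the handle only; the segment stays alive for reuse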
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index a369cf7f7f..4ac0a4dd2b 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -1,37 +1,42 @@
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
-# 只是为了更好的编程提示
+# Server startup arguments
@dataclass
class StartArgs:
run_mode: str = field(
default="normal",
- metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]},
+ metadata={
+ "choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"]
+ },
)
host: str = field(default="127.0.0.1")
port: int = field(default=8000)
+ httpserver_workers: int = field(default=1)
zmq_mode: str = field(
default="ipc:///tmp/",
metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"},
)
- pd_master_ip: str = field(default="127.0.0.1")
+ pd_master_ip: str = field(default="0.0.0.0")
pd_master_port: int = field(default=1212)
config_server_host: str = field(default=None)
config_server_port: int = field(default=None)
pd_decode_rpyc_port: int = field(default=None)
- select_p_d_node_strategy: str = field(default=None)
+ select_p_d_node_strategy: str = field(
+ default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]}
+ )
model_name: str = field(default="default_model_name")
model_dir: Optional[str] = field(default=None)
- tokenizer_mode: str = field(default="slow")
+ tokenizer_mode: str = field(default="fast")
load_way: str = field(default="HF")
max_total_token_num: Optional[int] = field(default=None)
mem_fraction: float = field(default=0.9)
batch_max_tokens: Optional[int] = field(default=None)
- eos_id: List[int] = field(default_factory=list)
+ eos_id: Optional[List[int]] = field(default=None)
tool_call_parser: Optional[str] = field(
- default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]}
+ default=None, metadata={"choices": ["qwen25", "llama3", "mistral", "deepseekv3", "qwen"]}
)
reasoning_parser: Optional[str] = field(
default=None,
@@ -59,7 +64,7 @@ class StartArgs:
dp: int = field(default=1)
nnodes: int = field(default=1)
node_rank: int = field(default=0)
- max_req_total_len: int = field(default=2048 + 1024)
+ max_req_total_len: int = field(default=16384)
nccl_host: str = field(default="127.0.0.1")
nccl_port: int = field(default=None)
use_config_server_to_init_nccl: bool = field(default=False)
@@ -75,7 +80,7 @@ class StartArgs:
disable_chunked_prefill: bool = field(default=False)
diverse_mode: bool = field(default=False)
token_healing_mode: bool = field(default=False)
- output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]})
+ output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]})
first_token_constraint_mode: bool = field(default=False)
enable_multimodal: bool = field(default=False)
enable_multimodal_audio: bool = field(default=False)
@@ -95,11 +100,11 @@ class StartArgs:
health_monitor: bool = field(default=False)
metric_gateway: Optional[str] = field(default=None)
job_name: str = field(default="lightllm")
- grouping_key: List[str] = field(default_factory=list)
+ grouping_key: List[str] = field(default_factory=lambda: [])
push_interval: int = field(default=10)
visual_infer_batch_size: int = field(default=None)
visual_send_batch_size: int = field(default=1)
- visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
+ visual_gpu_ids: List[int] = field(default=None)
visual_tp: int = field(default=1)
visual_dp: int = field(default=1)
visual_nccl_ports: List[int] = field(default=None)
@@ -111,9 +116,9 @@ class StartArgs:
graph_split_batch_size: int = field(default=32)
graph_grow_step_size: int = field(default=16)
graph_max_len_in_batch: int = field(default=0)
- quant_type: Optional[str] = field(default=None)
+ quant_type: Optional[str] = field(default="none")
quant_cfg: Optional[str] = field(default=None)
- vit_quant_type: Optional[str] = field(default=None)
+ vit_quant_type: Optional[str] = field(default="none")
vit_quant_cfg: Optional[str] = field(default=None)
llm_prefill_att_backend: List[str] = field(
default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
@@ -144,7 +149,7 @@ class StartArgs:
pd_node_id: int = field(default=-1)
enable_cpu_cache: bool = field(default=False)
cpu_cache_storage_size: float = field(default=2)
- cpu_cache_token_page_size: int = field(default=64)
+ cpu_cache_token_page_size: int = field(default=256)
enable_disk_cache: bool = field(default=False)
disk_cache_storage_size: float = field(default=10)
disk_cache_dir: Optional[str] = field(default=None)
@@ -162,3 +167,23 @@ class StartArgs:
# multi_modal
enable_multimodal: bool = field(default=False)
enable_multimodal_audio: bool = field(default=False)
+
+ disable_shm_warning: bool = field(default=False)
+ dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]})
+ enable_custom_allgather: bool = field(default=False)
+ enable_fused_shared_experts: bool = field(default=False)
+ enable_mps: bool = field(default=False)
+ multinode_router_gloo_port: int = field(default=20001)
+ schedule_time_interval: float = field(default=0.03)
+ use_dynamic_prompt_cache: bool = field(default=False)
+ disable_custom_allreduce: bool = field(default=False)
+ enable_torch_memory_saver: bool = field(default=False)
+ enable_weight_cpu_backup: bool = field(default=False)
+ hardware_platform: str = field(default="cuda", metadata={"choices": ["cuda", "musa"]})
+ enable_torch_fallback: bool = field(default=False)
+ enable_triton_fallback: bool = field(default=False)
+
+ enable_return_routed_experts: bool = field(default=False)
+
+    weight_version: str = field(default="default")
diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py
index 9aa3a8effc..c77379986c 100644
--- a/lightllm/server/detokenization/decode_req.py
+++ b/lightllm/server/detokenization/decode_req.py
@@ -62,11 +62,7 @@ def stop_sequences_str_match(self) -> bool:
return False
def need_detoken(self):
- if (
- (not self.req.is_aborted)
- and (not self.req.stop_str_matched)
- and len(self.output_ids) < self.req.candetoken_out_len
- ):
+ if (not self.req.stop_str_matched) and len(self.output_ids) < self.req.candetoken_out_len:
return True
return False
@@ -83,8 +79,6 @@ def get_decode_tokens(self):
return prefix_tokens, read_tokens
def can_set_release_mark(self):
- if self.req.is_aborted:
- return True
if self.req.stop_str_matched:
return True
if (
diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py
index 389171ba8a..17a47dfde6 100644
--- a/lightllm/server/detokenization/manager.py
+++ b/lightllm/server/detokenization/manager.py
@@ -6,7 +6,6 @@
import zmq
import inspect
from lightllm.server.core.objs import ShmReqManager, StartArgs
-from lightllm.server.core.objs.io_objs import GroupReqIndexes
from lightllm.utils.graceful_utils import graceful_registry
from typing import Union, Dict, List
from .decode import decode_token
@@ -17,6 +16,14 @@
import time
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.io_struct import (
+ BaseReq,
+ GenerateResp,
+ FlushCacheResp,
+ ReleaseMemoryResp,
+ ResumeMemoryResp,
+ GeneralModelToHttpRpcRsp,
+)
logger = init_logger(__name__)
@@ -31,9 +38,9 @@ def __init__(
self.zmq_recv_socket = context.socket(zmq.PULL)
self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}")
- self.pub_to_httpserver = context.socket(zmq.PUB)
- self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}")
- logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}")
+ self.send_to_httpserver = context.socket(zmq.PUSH)
+ self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}")
+ logger.info(f"send_to_httpserver sendhwm {self.send_to_httpserver.getsockopt(zmq.SNDHWM)}")
self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)
self.all_special_ids = set(self.tokenizer.all_special_ids)
self.req_id_to_out: Dict[int, DecodeReq] = {}
@@ -46,7 +53,7 @@ def _init_get_token_id_to_token_str(self):
self.token_id_to_token = {token_id: token for token, token_id in self.tokenizer.get_vocab().items()}
return
- def _add_new_group_req_index(self, recv_obj: GroupReqIndexes):
+ def _add_new_group_req_index(self, recv_obj: BaseReq):
for req_index in recv_obj.shm_req_indexes:
req = self.shm_req_manager.get_req_obj_by_index(req_index)
req.link_prompt_ids_shm_array()
@@ -74,8 +81,10 @@ def handle_loop(self):
try:
# 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
for _ in range(recv_max_count):
- recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
- assert isinstance(recv_obj, GroupReqIndexes)
+ recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+ if isinstance(recv_obj, GeneralModelToHttpRpcRsp):
+ self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL)
+ continue
self._add_new_group_req_index(recv_obj=recv_obj)
# 当队列中存在较多的请求时,将一次接受的数量上调
@@ -146,7 +155,7 @@ def gen_token_out(self):
# 通知 httpserver 进程
if exist_decode:
- self.pub_to_httpserver.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)
+ self.send_to_httpserver.send_pyobj(GenerateResp(), protocol=pickle.HIGHEST_PROTOCOL)
self.remove_finished_reqs()
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 212e037e90..3ef778ca42 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -10,6 +10,7 @@
import hashlib
import datetime
import pickle
+import inspect
from frozendict import frozendict
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -25,16 +26,41 @@
from lightllm.server.core.objs import Req, FinishStatus, StartArgs
from lightllm.server.core.objs import SamplingParams
from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
-from lightllm.server.core.objs.io_objs import GroupReqObjs
from lightllm.server.core.objs.shm_req_manager import ShmReqManager
from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.common.basemodel.routing_manager import get_routing_config_shm
from lightllm.utils.log_utils import init_logger
from lightllm.server.metrics.manager import MetricClient
+from lightllm.server.io_struct import (
+ AbortReq,
+ BaseReq,
+ FlushCacheReq,
+ FlushCacheResp,
+ GenerateReq,
+ GenerateResp,
+ GenerateReqMeta,
+ GenerateReqIndex,
+ ReleaseMemoryReq,
+ ReleaseMemoryResp,
+ ResumeMemoryReq,
+ ResumeMemoryResp,
+ InitWeightsUpdateGroupReq,
+ InitWeightsUpdateGroupRsp,
+ DestroyWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupRsp,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromDistributedRsp,
+ UpdateWeightsFromTensorReq,
+ UpdateWeightsFromTensorRsp,
+ GeneralHttpToModelRpcReq,
+ GeneralModelToHttpRpcRsp,
+)
from lightllm.utils.statics_utils import MovingAverage
from lightllm.utils.config_utils import get_vocab_size
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
from rpyc.utils.classic import obtain
logger = init_logger(__name__)
@@ -74,7 +100,7 @@ def __init__(
self.multinode_req_manager = context.socket(zmq.PULL)
self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}")
logger.info(
- f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}"
+ f"HttpServerManager listening for master node requests on *:{args.multinode_httpmanager_port}"
)
self.enable_multimodal = args.enable_multimodal
@@ -90,9 +116,8 @@ def __init__(
self.shm_req_manager = ShmReqManager()
# recv from detokenization
- self.zmq_recv_socket = context.socket(zmq.SUB)
+ self.zmq_recv_socket = context.socket(zmq.PULL)
self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}")
- self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"")
self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)
@@ -114,6 +139,18 @@ def __init__(
# If the timemark is not updated for a pre-set time, a prob request will be sent to the backend.
self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
self.latest_success_infer_time_mark.set_value(int(time.time()))
+
+ # Cache routing config for MoE expert routing data extraction
+ self._routing_shm = get_routing_config_shm() if args.enable_return_routed_experts else None
+
+ self.is_pause = False
+ self.is_pause_cond = asyncio.Condition()
+
+        # events used by interactive control requests
+ self.flush_cache_event: Optional[asyncio.Event] = None
+ self.release_memory_event: Optional[asyncio.Event] = None
+ self.resume_memory_event: Optional[asyncio.Event] = None
+ self.async_events_per_func: Dict[str, asyncio.Event] = {}
return
async def _alloc_resource(self, items, md5sums, token_nums, datas):
@@ -227,18 +264,32 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
async def loop_for_request(self):
assert self.args.node_rank > 0
while True:
- (
- prompt,
- sampling_params,
- multimodal_params,
- ) = await self.multinode_req_manager.recv_pyobj()
- results_generator = self.generate(prompt, sampling_params, multimodal_params, None)
+ req_obj = await self.multinode_req_manager.recv_pyobj()
+ if req_obj is None:
+ continue
+ if isinstance(req_obj, GenerateReqMeta):
+ self.process_generate_request(req_obj)
+ elif isinstance(req_obj, AbortReq):
+ self.process_abort_request(req_obj)
+ else:
+ assert False, f"Unknown request type: {type(req_obj)}"
+ return
+
+ def process_generate_request(self, req_meta: GenerateReqMeta):
+ prompt = req_meta.prompt
+ sampling_params = req_meta.sampling_params
+ multimodal_params = req_meta.multimodal_params
+ results_generator = self.generate(prompt, sampling_params, multimodal_params, None)
- async def generate_wrapper(results_generator):
- async for _, _, _, _ in results_generator:
- pass
+ async def generate_wrapper(results_generator):
+ async for _, _, _, _ in results_generator:
+ pass
+
+ asyncio.create_task(generate_wrapper(results_generator))
+ return
- asyncio.create_task(generate_wrapper(results_generator))
+ def process_abort_request(self, request: AbortReq):
+ asyncio.create_task(self.abort_request(request))
return
def alloc_req_id(self, sampling_params, is_health_req: bool = False):
@@ -281,15 +332,15 @@ async def generate(
group_request_id = self.alloc_req_id(sampling_params, is_health_req)
try:
- original_multimodal_params = None
- if self.is_multinode_tp_master:
- original_multimodal_params = copy.deepcopy(multimodal_params)
-
if self.pd_mode.is_P_or_NORMAL():
await multimodal_params.verify_and_preload(request)
# 记录请求到达的相关信息
await self._log_req_header(request_headers, group_request_id)
+
+ async with self.is_pause_cond:
+ await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
# encode
prompt_ids = await self._encode(prompt, multimodal_params, sampling_params)
@@ -348,12 +399,17 @@ async def generate(
)
req_objs.append(req_obj)
- req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time)
+ req_status = ReqStatus(
+ group_request_id=group_request_id,
+ prompt=prompt,
+ sampling_params=sampling_params,
+ multimodal_params=multimodal_params,
+ req_objs=req_objs,
+ start_time=start_time,
+ )
self.req_id_to_out_inf[group_request_id] = req_status
- await self.transfer_to_next_module_or_node(
- prompt, sampling_params, original_multimodal_params, req_status.group_req_objs
- )
+ await self.transfer_to_next_module_or_node(req_status.group_req_objs)
results_generator = self._wait_to_token_package(
start_time,
@@ -441,7 +497,21 @@ async def _encode(
# 这里的校验对多模态不是很充分, to do
if all(isinstance(e, int) for e in prompt):
- if not self.enable_multimodal and not self.pd_mode.is_D():
+ if self.enable_multimodal:
+ assert (
+ len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity
+ ), "too many multimodal items!"
+ if multimodal_params.audios:
+ assert self.args.enable_multimodal_audio, "audio multimodal not enabled"
+ await self._alloc_multimodal_resources(multimodal_params, sampling_params)
+ prompt_ids = self.tokenizer.encode(
+ prompt,
+ multimodal_params,
+ add_special_tokens=sampling_params.add_special_tokens,
+ already_tokenized=True,
+ )
+ return prompt_ids
+ elif not self.enable_multimodal and not self.pd_mode.is_D():
if all(e < self.vocab_size for e in prompt):
return prompt
else:
@@ -484,44 +554,49 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
async def transfer_to_next_module_or_node(
self,
- prompt: str,
- sampling_params: SamplingParams,
- original_multimodal_params: MultimodalParams,
- group_req_objs: Optional[GroupReqObjs] = None,
+ req_obj: Optional["BaseReq"] = None,
):
# 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点.
+ req_to_next_node = req_obj.get_req_to_next_node()
+ self.transfer_to_next_node(req_to_next_node)
+ req_to_next_module = req_obj.get_req_to_next_module()
+ await self.transfer_to_next_module(req_to_next_module)
+ return
+
+ def transfer_to_next_node(
+ self,
+ req_to_next_node: Optional["BaseReq"] = None,
+ ):
if self.is_multinode_tp_master:
for sender in self.multinode_req_manager:
sender.send_pyobj(
- (prompt, sampling_params, original_multimodal_params),
+ req_to_next_node,
protocol=pickle.HIGHEST_PROTOCOL,
)
-
- await self.transfer_to_next_module(group_req_objs)
return
async def transfer_to_next_module(
self,
- group_req_objs: Optional[GroupReqObjs] = None,
+ req_to_next_module: Optional["GenerateReqIndex"] = None,
):
if self.pd_mode.is_P_or_NORMAL():
if self.enable_multimodal:
self.send_to_visual.send_pyobj(
- group_req_objs.to_group_req_index(),
+ req_to_next_module,
protocol=pickle.HIGHEST_PROTOCOL,
)
return
if self.args.enable_cpu_cache:
self.send_to_multi_level_kv_cache.send_pyobj(
- group_req_objs.to_group_req_index(),
+ req_to_next_module,
protocol=pickle.HIGHEST_PROTOCOL,
)
return
self.send_to_router.send_pyobj(
- group_req_objs.to_group_req_index(),
+ req_to_next_module,
protocol=pickle.HIGHEST_PROTOCOL,
)
return
@@ -529,7 +604,7 @@ async def transfer_to_next_module(
if self.pd_mode.is_D():
# 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了
self.send_to_router.send_pyobj(
- group_req_objs.to_group_req_index(),
+ req_to_next_module,
protocol=pickle.HIGHEST_PROTOCOL,
)
return
@@ -659,12 +734,24 @@ async def abort(self, group_req_id: int) -> bool:
logger.warning(f"aborted group_request_id {group_req_id} not exist")
return False
- group_req_objs: GroupReqObjs = req_status.group_req_objs
+ group_req_objs: GenerateReq = req_status.group_req_objs
for req in group_req_objs.shm_req_objs:
req.is_aborted = True
logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}")
return True
+ async def abort_request(self, request: AbortReq):
+ request_id = request.request_id
+ abort_all = request.abort_all
+ if self.is_multinode_tp_master:
+ self.transfer_to_next_node(req_to_next_node=request)
+ if request_id is not None and not abort_all:
+ await self.abort(request_id)
+ if abort_all:
+ for group_req_id in list(self.req_id_to_out_inf.keys()):
+ await self.abort(group_req_id)
+        return
+
async def recycle_resource_loop(self):
pre_time_mark = time.time()
@@ -686,6 +773,11 @@ async def recycle_resource_loop(self):
for req_status in release_req_status:
self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None)
for req in req_status.group_req_objs.shm_req_objs:
+ if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None:
+ try:
+ req.close_routing_data_shm_array()
+ except Exception as e:
+ logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}")
await self.shm_req_manager.async_put_back_req_obj(req)
await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem)
await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params)
@@ -721,65 +813,16 @@ async def handle_loop(self):
while True:
try:
- await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05)
+ recv_obj = await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05)
except asyncio.TimeoutError:
- pass
+ recv_obj = None
try:
- for group_req_id_ in list(self.req_id_to_out_inf.keys()):
- req_status = self.req_id_to_out_inf.get(group_req_id_, None)
- if req_status is None:
- continue
-
- token_list = []
- for req in req_status.group_req_objs.shm_req_objs:
- req_id = req.request_id
- read_token_count = 1
- if req.out_tokens_queue.is_full():
- read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
-
- for _ in range(read_token_count):
- if not req.out_tokens_queue.is_empty():
-
- text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
- req.cumlogprob += float(req.shm_logprobs.arr[src_index])
- metadata = {
- "id": int(req.shm_prompt_ids.arr[src_index]),
- "logprob": float(req.shm_logprobs.arr[src_index]),
- "cumlogprob": float(req.cumlogprob) / count_output_tokens,
- "special": special,
- "count_output_tokens": count_output_tokens,
- "prompt_cache_len": req.prompt_cache_len,
- "cpu_prompt_cache_len": req.cpu_prompt_cache_len,
- "disk_prompt_cache_len": req.disk_prompt_cache_len,
- "mtp_accepted_token_num": req.mtp_accepted_token_num,
- }
- if self.args.return_all_prompt_logprobs:
- metadata.update(req.get_all_prompt_metadata())
- if self.args.use_reward_model:
- metadata["score"] = float(req.reward_score)
-
- req.out_tokens_queue.pop_no_ret()
-
- finished_token_index = (
- req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index
- )
-
- if finished_token_index != src_index:
- token_list.append((req_id, text, metadata, FinishStatus()))
- else:
- if req.stop_str_matched:
- finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
- else:
- finish_status = FinishStatus(req.finish_status.status)
-
- token_list.append((req_id, text, metadata, finish_status))
- else:
- break
+ if recv_obj is None or isinstance(recv_obj, GenerateResp):
+ await self._handle_recv_generate_request(recv_obj)
+ elif isinstance(recv_obj, GeneralModelToHttpRpcRsp):
+ await self._handle_recv_general_model_to_http_request(recv_obj)
- async with req_status.lock:
- req_status.out_token_info_list.extend(token_list)
- req_status.event.set()
except BaseException as e:
logger.exception(str(e))
raise e
@@ -787,13 +830,189 @@ async def handle_loop(self):
self.recycle_event.set()
return
+    async def _handle_recv_generate_request(self, recv_obj: Optional[GenerateResp]):
+ for group_req_id_ in list(self.req_id_to_out_inf.keys()):
+ req_status = self.req_id_to_out_inf.get(group_req_id_, None)
+ if req_status is None:
+ continue
+
+ token_list = []
+ for req in req_status.group_req_objs.shm_req_objs:
+ req_id = req.request_id
+ read_token_count = 1
+ if req.out_tokens_queue.is_full():
+ read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
+
+ for _ in range(read_token_count):
+ if not req.out_tokens_queue.is_empty():
+
+ text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
+ req.cumlogprob += float(req.shm_logprobs.arr[src_index])
+ metadata = {
+ "id": int(req.shm_prompt_ids.arr[src_index]),
+ "logprob": float(req.shm_logprobs.arr[src_index]),
+ "cumlogprob": float(req.cumlogprob) / count_output_tokens,
+ "special": special,
+ "count_output_tokens": count_output_tokens,
+ "prompt_cache_len": req.prompt_cache_len,
+ "cpu_prompt_cache_len": req.cpu_prompt_cache_len,
+ "mtp_accepted_token_num": req.mtp_accepted_token_num,
+ }
+ if self.args.return_all_prompt_logprobs:
+ metadata.update(req.get_all_prompt_metadata())
+ if self.args.use_reward_model:
+ metadata["score"] = float(req.reward_score)
+
+ req.out_tokens_queue.pop_no_ret()
+
+ finished_token_index = (
+ req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index
+ )
+
+ if finished_token_index != src_index:
+ token_list.append((req_id, text, metadata, FinishStatus()))
+ else:
+ if req.stop_str_matched:
+ finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
+ else:
+ finish_status = FinishStatus(req.finish_status.status)
+
+ if self._routing_shm is not None:
+ _num_moe = int(self._routing_shm.arr[0])
+ _topk = int(self._routing_shm.arr[1])
+ _dtype_id = int(self._routing_shm.arr[2])
+ if _num_moe > 0:
+ routing_meta = req.get_routing_metadata(_num_moe, _topk, dtype_id=_dtype_id)
+ if routing_meta is not None:
+ metadata["routed_experts"] = routing_meta
+
+ token_list.append((req_id, text, metadata, finish_status))
+ else:
+ break
+
+ async with req_status.lock:
+ req_status.out_token_info_list.extend(token_list)
+ req_status.event.set()
+
+ async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralModelToHttpRpcRsp):
+ assert recv_obj.func_name is not None
+ event = await self.get_event_for_func(recv_obj.func_name)
+ event.result = recv_obj
+ event.set()
+ return
+
+ async def pause_generation(self):
+        # Requests are forwarded from the master node to the slave nodes,
+        # so pausing the master is enough to pause the slaves as well.
+ if self.is_pause:
+ return
+ async with self.is_pause_cond:
+ self.is_pause = True
+ while True:
+ await self.abort_request(AbortReq(request_id=None, abort_all=True))
+ running_req_num = len(list(self.req_id_to_out_inf.keys()))
+ if running_req_num == 0:
+ break
+ await asyncio.sleep(1.0)
+
+ async def continue_generation(self):
+ async with self.is_pause_cond:
+ self.is_pause = False
+ self.is_pause_cond.notify_all()
+
+ async def get_event_for_func(self, func_name: str) -> asyncio.Event:
+ if func_name not in self.async_events_per_func:
+ self.async_events_per_func[func_name] = asyncio.Event()
+ return self.async_events_per_func[func_name]
+
+ async def http_to_model_special_request(
+ self, request: GeneralHttpToModelRpcReq, timeout: int = 300
+ ) -> GeneralModelToHttpRpcRsp:
+        event = await self.get_event_for_func(request.func_name)
+        event.clear()  # drop any stale result left over from a previous call
+ await self.transfer_to_next_module(request)
+ try:
+ await asyncio.wait_for(event.wait(), timeout=timeout)
+ ret = event.result
+
+ except asyncio.TimeoutError:
+ ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response timeout", func_name=request.func_name)
+ except Exception as e:
+ ret = GeneralModelToHttpRpcRsp(
+ success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name
+ )
+ return ret
+
+ async def flush_cache(self, request: FlushCacheReq):
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="flush_cache", func_args=request)
+ )
+
+ async def release_memory_occupation(self, request: ReleaseMemoryReq):
+ assert len(self.req_id_to_out_inf) == 0, "there are still requests running, cannot release memory occupation"
+        # stop accepting new requests until resume is called
+ await self.pause_generation()
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="release_memory_occupation", func_args=request.tags)
+ )
+
+ async def resume_memory_occupation(self, request: ResumeMemoryReq):
+ ret = await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="resume_memory_occupation", func_args=request.tags)
+ )
+ if ret.success:
+ await self.continue_generation()
+ return ret
+
+ async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq):
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="init_weights_update_group", func_args=request)
+ )
+
+ async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq):
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="destroy_weights_update_group", func_args=request)
+ )
+
+ async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq):
+
+ if request.abort_all_requests:
+ await self.abort_request(AbortReq(abort_all=True))
+
+ if request.flush_cache:
+ await self.flush_cache(FlushCacheReq())
+
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request)
+ )
+
+    async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> GeneralModelToHttpRpcRsp:
+ if request.abort_all_requests:
+ await self.abort_request(AbortReq(abort_all=True))
+
+ if request.flush_cache:
+ await self.flush_cache(FlushCacheReq())
+
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="update_weights_from_tensor", func_args=request)
+ )
+
class ReqStatus:
- def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None:
+ def __init__(
+ self,
+ group_request_id: int,
+ prompt: str,
+ sampling_params: SamplingParams,
+ multimodal_params: MultimodalParams,
+ req_objs: List[Req],
+ start_time,
+ ) -> None:
self.lock = asyncio.Lock()
self.event = asyncio.Event()
- self.group_req_objs = GroupReqObjs(
+ self.group_req_objs = GenerateReq(
group_req_id=group_request_id,
+ prompt=prompt,
+ sampling_params=sampling_params,
multimodal_params=multimodal_params,
shm_req_objs=req_objs,
time_mark=start_time,
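Reviewer note: the control-plane methods above all funnel through http_to_model_special_request; a condensed sketch of the round trip, assuming manager is an HttpServerManager instance and the code runs inside its event loop:

    rsp = await manager.http_to_model_special_request(
        GeneralHttpToModelRpcReq(func_name="flush_cache", func_args=FlushCacheReq()),
        timeout=60,
    )
    if not rsp.success:
        print("flush_cache failed:", rsp.msg)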
diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py
new file mode 100644
index 0000000000..e04e8871ce
--- /dev/null
+++ b/lightllm/server/io_struct.py
@@ -0,0 +1,195 @@
+from abc import ABC
+from dataclasses import dataclass
+from lightllm.server.core.objs.req import Req
+from lightllm.server.core.objs.sampling_params import SamplingParams
+from lightllm.server.multimodal_params import MultimodalParams
+from typing import List, Optional, Any, Union
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
+
+
+@dataclass
+class BaseReq(ABC):
+ def get_req_to_next_node(self):
+ return self
+
+ def get_req_to_next_module(self):
+ return self
+
+
+@dataclass
+class BaseRsp(ABC):
+ success: bool
+ msg: Optional[str]
+
+
+# for next node
+@dataclass
+class GenerateReqMeta(BaseReq):
+ prompt: str
+ sampling_params: SamplingParams
+ multimodal_params: MultimodalParams
+
+
+# for next module
+@dataclass
+class GenerateReqIndex(BaseReq):
+ group_req_id: int
+ multimodal_params: MultimodalParams
+ shm_req_indexes: List[int]
+ time_mark: float
+
+
+@dataclass
+class GenerateReq(BaseReq):
+ group_req_id: int
+ prompt: str
+ sampling_params: SamplingParams
+ multimodal_params: MultimodalParams
+ shm_req_objs: List[Req]
+ time_mark: float
+
+ def get_req_to_next_module(self):
+        # cross-node forwarding is done, so the raw image data can be freed
+ self.multimodal_params.free()
+ return GenerateReqIndex(
+ group_req_id=self.group_req_id,
+ multimodal_params=self.multimodal_params,
+ shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs],
+ time_mark=self.time_mark,
+ )
+
+ def get_req_to_next_node(self):
+ return GenerateReqMeta(
+ prompt=self.prompt,
+ sampling_params=self.sampling_params,
+ multimodal_params=self.multimodal_params,
+ )
+
+
+@dataclass
+class GenerateResp(BaseReq):
+ pass
+
+
+@dataclass
+class FlushCacheReq(BaseReq):
+ pass
+
+
+@dataclass
+class FlushCacheResp(BaseReq):
+ success: bool
+
+
+@dataclass
+class AbortReq(BaseReq):
+    # provided by external callers; equivalent to the internal group_req_id
+    request_id: Optional[int] = None
+ abort_all: bool = False
+
+
+@dataclass
+class ReleaseMemoryReq(BaseReq):
+ tags: Optional[List[MemoryTag]] = None
+
+
+@dataclass
+class ReleaseMemoryResp(BaseReq):
+ success: bool
+
+
+@dataclass
+class ResumeMemoryReq(BaseReq):
+ tags: Optional[List[MemoryTag]] = None
+
+
+@dataclass
+class ResumeMemoryResp(BaseReq):
+ success: bool
+
+
+@dataclass
+class GeneralHttpToModelRpcReq(BaseReq):
+ func_name: str
+ func_args: Optional[Any] = None
+
+
+@dataclass
+class GeneralModelToHttpRpcRsp(BaseRsp):
+ func_name: str
+ func_rsp: Optional[Any] = None
+
+
+@dataclass
+class InitWeightsUpdateGroupReq(BaseReq):
+ # The master address
+ master_address: str
+ # The master port
+ master_port: int
+ # The rank offset
+ rank_offset: int
+ # The world size
+ world_size: int
+ # The group name
+ group_name: str = "weight_update_group"
+ # The backend
+ backend: str = "nccl"
+
+
+@dataclass
+class InitWeightsUpdateGroupRsp(BaseRsp):
+ pass
+
+
+@dataclass
+class DestroyWeightsUpdateGroupReq(BaseReq):
+ group_name: str = "weight_update_group"
+
+
+@dataclass
+class DestroyWeightsUpdateGroupRsp(BaseRsp):
+ pass
+
+
+@dataclass
+class UpdateWeightsFromDistributedReq(BaseReq):
+ names: List[str]
+ dtypes: List[str]
+ shapes: List[List[int]]
+ # The group name
+ group_name: str = "weight_update_group"
+ # Whether to flush the cache after updating weights
+ flush_cache: bool = True
+ # Whether to abort all requests before updating weights
+ abort_all_requests: bool = False
+ # Optional: Update weight version along with weights
+ weight_version: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightsFromDistributedRsp(BaseRsp):
+ pass
+
+
+@dataclass
+class UpdateWeightsFromTensorReq(BaseReq):
+ """Update model weights from tensor input.
+
+ - Tensors are serialized for transmission
+ - Data is structured in JSON for easy transmission over HTTP
+ """
+
+ serialized_named_tensors: List[Union[str, bytes]]
+ # Optional format specification for loading
+ load_format: Optional[str] = None
+ # Whether to flush the cache after updating weights
+ flush_cache: bool = True
+ # Whether to abort all requests before updating weights
+ abort_all_requests: bool = False
+ # Optional: Update weight version along with weights
+ weight_version: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightsFromTensorRsp(BaseRsp):
+ pass
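Reviewer note: GenerateReq replaces the old GroupReqObjs and yields two different payloads depending on the hop. A sketch, assuming sp, mm and reqs were built as in HttpServerManager.generate:

    import time

    req = GenerateReq(
        group_req_id=1,
        prompt="hi",
        sampling_params=sp,
        multimodal_params=mm,
        shm_req_objs=reqs,
        time_mark=time.time(),
    )
    to_slave_node = req.get_req_to_next_node()     # GenerateReqMeta: prompt + params only
    to_next_module = req.get_req_to_next_module()  # GenerateReqIndex: shm indexes; frees raw image data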
diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py
index 1de1b502c9..b8ff820ec5 100644
--- a/lightllm/server/multi_level_kv_cache/manager.py
+++ b/lightllm/server/multi_level_kv_cache/manager.py
@@ -12,7 +12,7 @@
from queue import Queue
from typing import List
from lightllm.server.core.objs import ShmReqManager, Req, StartArgs
-from lightllm.server.core.objs.io_objs import GroupReqIndexes
+from lightllm.server.io_struct import GenerateReqIndex, BaseReq
from lightllm.utils.graceful_utils import graceful_registry
from .cpu_cache_client import CpuKvCacheClient
from lightllm.utils.log_utils import init_logger
@@ -135,7 +135,7 @@ def _disk_cache_match(self, token_hash_list: List[int], all_pages: List[int]) ->
self.cpu_cache_client.lock.release()
return all_pages, len(new_page_indexes)
- def _handle_group_req_multi_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float):
+ def _handle_group_req_multi_cache_match(self, group_req_indexes: GenerateReqIndex, start_time: float):
"""
match cpu cache and disk cache pages
"""
@@ -198,8 +198,9 @@ def recv_loop(self):
try:
# 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
for _ in range(recv_max_count):
- recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
- assert isinstance(recv_obj, GroupReqIndexes)
+ recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+ if not isinstance(recv_obj, GenerateReqIndex):
+ continue
recv_objs.append(recv_obj)
start_time = recv_obj.time_mark
diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py
index d957c7649a..72db3d80cd 100644
--- a/lightllm/server/multimodal_params.py
+++ b/lightllm/server/multimodal_params.py
@@ -56,9 +56,11 @@ def read(self):
assert self._preload_data is not None
ans = self._preload_data
self._preload_data = None
- self._data = None
return ans
+ def free(self):
+ self._data = None
+
def to_dict(self):
ret = {}
ret["uuid"] = self.uuid
@@ -121,9 +123,11 @@ def read(self):
assert self._preload_data is not None
ans = self._preload_data
self._preload_data = None
- self._data = None
return ans
+ def free(self):
+ self._data = None
+
def to_dict(self):
ret = {}
ret["uuid"] = self.uuid
@@ -174,3 +178,10 @@ def to_origin_dict(self):
ret = {}
ret["images"] = [i.to_origin_dict() for i in self.images]
return ret
+
+ def free(self):
+ for image in self.images:
+ image.free()
+ for audio in self.audios:
+ audio.free()
+ return
diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py
index 9bf9040c30..20da121dc0 100644
--- a/lightllm/server/req_id_generator.py
+++ b/lightllm/server/req_id_generator.py
@@ -30,7 +30,8 @@ def __init__(self):
self.current_id.arr[0] = 0
self.current_id.arr[1] = 0
self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock")
- self._wait_all_workers_ready()
+ if self.args.httpserver_workers > 1:
+ self._wait_all_workers_ready()
logger.info("ReqIDGenerator init finished")
def _wait_all_workers_ready(self):
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index c517748984..e96b1d1a32 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -424,6 +424,30 @@ def clear_tree_nodes(self):
self.refed_tokens_num.arr[0] = 0
return
+ def flush_cache(self):
+ nodes_to_clear = collections.deque(self.root_node.children.values())
+ self.root_node.children.clear()
+ while nodes_to_clear:
+ node = nodes_to_clear.popleft()
+ nodes_to_clear.extend(node.children.values())
+ node.parent = None
+ node.children.clear()
+
+ self.root_node.token_id_key[:] = 0
+ self.root_node.token_mem_index_value[:] = 0
+        self.root_node.ref_counter = 1  # keep at 1 so the root node can never be evicted
+ self.root_node.time_id = time_gen.generate_time_id()
+ self.root_node.node_value_len = 0
+ self.root_node.node_prefix_total_len = 0
+
+ self.evict_tree_set.clear()
+ self.evict_tree_set.add(self.root_node)
+
+ self.tree_total_tokens_num.arr[0] = 0
+ self.refed_tokens_num.arr[0] = 0
+
+ return
+
def dec_node_ref_counter(self, node: TreeNode):
if node is None:
return
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index ac5c1abee3..7d5cfa7bea 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -5,6 +5,7 @@
import pickle
import inspect
import setproctitle
+import rpyc
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
import zmq
@@ -17,7 +18,6 @@
from .model_infer.model_rpc import start_model_process, ModelRpcClient
from .req_queue import build_req_queue
from lightllm.server.core.objs.io_objs import (
- GroupReqIndexes,
AbortedReqCmd,
StopStrMatchedReqCmd,
)
@@ -29,11 +29,23 @@
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.metrics.manager import MetricClient
from lightllm.common.basemodel.infer_lock import g_router_lock
+from lightllm.server.io_struct import (
+ BaseReq,
+ GenerateReqIndex,
+ FlushCacheReq,
+ FlushCacheResp,
+ ReleaseMemoryReq,
+ ReleaseMemoryResp,
+ ResumeMemoryReq,
+ ResumeMemoryResp,
+ GeneralHttpToModelRpcReq,
+ GeneralModelToHttpRpcRsp,
+)
from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
-
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
logger = init_logger(__name__)
@@ -356,8 +368,13 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]:
ans = []
if self.running_batch is None:
return ans
- for req in self.running_batch.reqs:
- if req.is_aborted and req._router_aborted is False:
+ aborted_req_mask = torch.tensor(
+ [req.is_aborted for req in self.running_batch.reqs], dtype=torch.bool, device="cpu"
+ )
+ if self.is_multinode_tp:
+ dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+ for req, is_aborted in zip(self.running_batch.reqs, aborted_req_mask.numpy()):
+ if is_aborted and req._router_aborted is False:
req._router_aborted = True
ans.append(req)
return ans
@@ -406,7 +423,7 @@ def get_used_tokens(self, dp_index):
else:
return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
- def _add_req(self, group_req_indexes: GroupReqIndexes):
+ def _add_req(self, group_req_indexes: BaseReq):
req_group = []
for req_index in group_req_indexes.shm_req_indexes:
req = self.shm_req_manager.get_req_obj_by_index(req_index)
@@ -452,9 +469,22 @@ def _multinode_tp_generate_new_batch(self):
dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
req_id_select_mark = [1 for _ in range(len(req_ids))]
req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            # TODO: req_id_select_mark and aborted_req_mask could be fused into a single all_reduce here
dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+ aborted_req_mask = torch.tensor(
+ [req.is_aborted for req in new_batch.reqs], dtype=torch.bool, device="cpu"
+ )
+ dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
back_req_list = []
- for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+ for req_id, select, is_aborted in zip(
+ req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy()
+ ):
+                # Free multinode aborted requests; if select == 0, is_aborted is guaranteed to be False
+ if is_aborted and select == 1:
+ req = new_batch.pop_req(req_id)
+ self.req_queue.free_aborted_req(req)
+ self.shm_req_manager.put_back_req_obj(req)
+ continue
if select == 0:
req = new_batch.pop_req(req_id)
back_req_list.append(req)
@@ -470,23 +500,28 @@ def _multinode_tp_generate_new_batch(self):
else:
req_ids = [None for _ in range(req_num)]
dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
- all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+ # all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+ id_to_req_obj = {req.request_id: req for req in self.req_queue.waiting_req_list}
req_id_select_mark = []
+ aborted_req_mask = []
for req_id in req_ids:
- req_id_select_mark.append(1 if req_id in all_req_id_set else 0)
+ req_id_select_mark.append(1 if req_id in id_to_req_obj else 0)
+ aborted_req_mask.append(id_to_req_obj[req_id].is_aborted if req_id in id_to_req_obj else False)
req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
- select_req_ids = []
- for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
- if select == 1:
- select_req_ids.append(req_id)
-
+ aborted_req_mask = torch.tensor(aborted_req_mask, dtype=torch.bool, device="cpu")
+ dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
select_reqs = []
- for req_id in select_req_ids:
- for req in self.req_queue.waiting_req_list:
- if req.request_id == req_id:
- select_reqs.append(req)
-
+ for req_id, select, is_aborted in zip(
+ req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy()
+ ):
+ if select == 1:
+ req = id_to_req_obj[req_id]
+ if is_aborted:
+ self.req_queue.free_aborted_req(req)
+ self.shm_req_manager.put_back_req_obj(req)
+ continue
+ select_reqs.append(req)
for req in select_reqs:
self.req_queue.waiting_req_list.remove(req)
if select_reqs:
@@ -507,13 +542,17 @@ async def _recv_new_reqs_and_schedule(self):
self.recv_max_count = 64
try:
+            # Requests that multinode tp must broadcast to the other nodes
+ special_reqs = []
# 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
for _ in range(self.recv_max_count):
- recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
- if isinstance(recv_req, GroupReqIndexes):
+ recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+ if isinstance(recv_req, GenerateReqIndex):
self._add_req(recv_req)
+ elif isinstance(recv_req, GeneralHttpToModelRpcReq):
+ special_reqs.append(recv_req)
else:
- assert False, f"Error Req Inf {recv_req}"
+ raise ValueError(f"Unknown request type: {type(recv_req)}")
# 当队列中存在较多的请求时,将一次接受的数量上调
self.recv_max_count = min(int(self.recv_max_count * 1.3), 256)
@@ -522,6 +561,8 @@ async def _recv_new_reqs_and_schedule(self):
# 当队列已经开始清空的时候,将一次接受的数量下调
self.recv_max_count = 64
+ self._process_special_reqs(special_reqs)
+
if self.is_multinode_tp:
self._multinode_tp_generate_new_batch()
else:
@@ -529,6 +570,44 @@ async def _recv_new_reqs_and_schedule(self):
self._generate_new_batch()
return
+ def _process_special_reqs(self, special_reqs: List[BaseReq]):
+ if self.is_multinode_tp:
+ special_reqs = self.broadcast_reqs_to_other_nodes(special_reqs)
+ for req in special_reqs:
+ assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq"
+ self.forward_to_model(req)
+
+ def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]):
+ req_num = len(reqs)
+ if self.node_rank == 0:
+ req_nums = [len(reqs)]
+ dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+ req_num = req_nums[0]
+ if req_num > 0:
+ dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group)
+ else:
+ req_nums = [None]
+ dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+ req_num = req_nums[0]
+ if req_num > 0:
+ reqs = [None for _ in range(req_num)]
+ dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group)
+ return reqs
+
+ def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None:
+ ret = self.model_rpc_client.forward_to_model(req)
+        if self.is_multinode_tp:
+            output_list = [None for _ in range(self.nnodes)] if self.node_rank == 0 else None
+            dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group)
+            if self.node_rank == 0:
+                for res in output_list:
+                    res: GeneralModelToHttpRpcRsp
+                    if not res.success:
+                        ret = res
+                        break
+
+        if self.node_rank == 0:
+            self.send_to_detokenization.send_pyobj(ret, protocol=pickle.HIGHEST_PROTOCOL)
+
def clean_up(self):
return
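
The abort handling above reduces the per-node `is_aborted` masks with `ReduceOp.MIN`, so a request is only released once every node has observed the abort. A small standalone sketch of that semantics (plain tensors stand in for the `dist.all_reduce` call):

```python
import torch

# Per-node view of which requests in the batch are aborted (True = aborted locally).
node_masks = [
    torch.tensor([True,  False, True]),   # node 0
    torch.tensor([True,  True,  True]),   # node 1
    torch.tensor([False, True,  True]),   # node 2
]

# MIN over booleans behaves like a logical AND across nodes: a request counts as
# aborted only when every node agrees, so no node frees resources for a request
# that a peer is still about to schedule.
global_mask = torch.stack(node_masks).to(torch.uint8).min(dim=0).values.bool()
# -> tensor([False, False, True]): only the last request is aborted on every node.
```
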
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 4b8b3c538f..66aeb6e95d 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -19,6 +19,7 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo
from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
+from lightllm.common.basemodel import routing_manager as _routing_mgr
logger = init_logger(__name__)
@@ -113,6 +114,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
return req_objs
+ def _extract_routing_data(self, req: "InferReq"):
+ if req.shm_req.shm_routing_num_tokens > 0:
+ return
+ mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]
+ mgr = _routing_mgr.g_routing_capture_manager
+ routing_data = mgr.extract_routing_data(mem_indexes)
+ req.shm_req.create_routing_data_shm_array(mgr.num_moe_layers, req.cur_kv_len, mgr.topk, np_dtype=mgr.np_dtype)
+ req.shm_req.shm_routing_data.arr[:] = routing_data
+ req.shm_req.shm_routing_data.detach_shm()
+
def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
if self.radix_cache is None:
free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
@@ -149,12 +160,18 @@ def _filter(self, finished_request_ids: List[int]):
if len(finished_request_ids) == 0:
return
+ need_routing_data = _routing_mgr.g_routing_capture_manager is not None
+
free_req_index = []
free_token_index = []
for request_id in finished_request_ids:
req: InferReq = self.requests_mapping.pop(request_id)
if self.args.diverse_mode:
req.clear_master_slave_state()
+
+ if need_routing_data:
+ self._extract_routing_data(req)
+
self.free_a_req_mem(free_token_index, req)
free_req_index.append(req.req_idx)
@@ -266,6 +283,7 @@ def __init__(
self.fsm_current_state: int = 0
self.allowed_token_ids = self.shm_param.allowed_token_ids.to_list()
+ self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list()
if len(self.allowed_token_ids) == 0:
self.allowed_token_ids = None
@@ -281,6 +299,11 @@ def __init__(
logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids")
self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size]
+ if len(self.invalid_token_ids) > 0:
+ if not all(e < vocab_size for e in self.invalid_token_ids):
+                logger.error("invalid_token_ids contain token id >= vocab_size, these token ids will be removed")
+ self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size]
+
# nixl decode node information
if self.shm_param.nixl_params.data_len > 0:
self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get())
@@ -491,6 +514,8 @@ def update_finish_status(self, eos_ids, output_len: int):
self.finish_status.set_status(FinishStatus.FINISHED_STOP)
elif output_len >= self.sampling_param.shm_param.max_new_tokens:
self.finish_status.set_status(FinishStatus.FINISHED_LENGTH)
+ elif self.infer_aborted:
+ self.finish_status.set_status(FinishStatus.FINISHED_ABORTED)
return
def _stop_sequences_matched(self, output_len: int):
@@ -580,6 +605,8 @@ def handle(
shm_req.shm_cur_output_len = self.output_len
if finish_status.is_finished():
+ if _routing_mgr.g_routing_capture_manager is not None:
+ g_infer_context._extract_routing_data(req_obj)
shm_req.finish_token_index = shm_req.input_len + self.output_len - 1
shm_req.finish_status = req_obj.finish_status
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 64310d6b0d..70b0ec9ebf 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -4,7 +4,7 @@
import time
import threading
import torch.distributed as dist
-from typing import List, Tuple, Callable, Optional
+from typing import List, Tuple, Callable, Optional, Union
from transformers.configuration_utils import PretrainedConfig
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.log_utils import init_logger
@@ -16,7 +16,7 @@
from lightllm.common.basemodel.basemodel import TpPartBaseModel
from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput
from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify
-from lightllm.utils.dist_utils import init_distributed_env
+from lightllm.utils.dist_utils import init_distributed_env, init_custom_process_group
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.server.core.objs import ShmReqManager, StartArgs
from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd
@@ -31,6 +31,9 @@
enable_radix_tree_timer_merge,
get_radix_tree_merge_update_delta,
)
+from lightllm.utils.serializer import LocalSerializedTensor, MultiprocessingSerializer
+from lightllm.utils.patch_torch import monkey_patch_torch_reductions
+from lightllm.utils.tensor_bucket import FlattenedTensorBucket, FlattenedTensorMetadata
from lightllm.distributed import dist_group_manager
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
@@ -41,7 +44,16 @@
from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
+from lightllm.common.basemodel import routing_manager as _routing_mgr
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
from .multi_level_kv_cache import MultiLevelKvCacheModule
+from lightllm.server.io_struct import (
+ FlushCacheReq,
+ InitWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupReq,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromTensorReq,
+)
class ModeBackend:
@@ -114,6 +126,8 @@ def init_model(self, kvargs):
)
dist_group_manager.create_groups(group_size=group_size) # set the default group
+ self._model_update_group = {}
+
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
# 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在
@@ -338,6 +352,191 @@ def init_mtp_draft_model(self, main_kvargs: dict):
self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
return
+ def flush_cache(self, request: FlushCacheReq):
+ if self.radix_cache is not None:
+ self.radix_cache.flush_cache()
+ return True, "Succeeded to flush cache."
+
+ def release_memory_occupation(self, tags: List[MemoryTag]):
+ try:
+ self.model.release_memory_occupation(tags)
+ self.flush_cache(request=None)
+ self.model.req_manager.free_all()
+ self.model.mem_manager.free_all()
+ return True, "Succeeded to release memory occupation."
+ except Exception as e:
+ self.logger.error(f"release memory occupation failed: {str(e)}")
+ return False, f"release memory occupation failed: {str(e)}"
+
+ def resume_memory_occupation(self, tags: List[MemoryTag]):
+ try:
+ self.model.resume_memory_occupation(tags)
+ return True, "Succeeded to resume memory occupation."
+ except Exception as e:
+ self.logger.error(f"resume memory occupation failed: {str(e)}")
+ return False, f"resume memory occupation failed: {str(e)}"
+
+ def init_weights_update_group(self, request: InitWeightsUpdateGroupReq):
+ assert torch.distributed.is_initialized(), "Default torch process group must be initialized"
+
+ assert request.group_name != "", "Group name cannot be empty"
+ rank_offset = request.rank_offset
+ rank = rank_offset + self.rank_in_dp
+ world_size = request.world_size
+ group_name = request.group_name
+ self.logger.info(
+ f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, "
+ f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, "
+ f" backend={request.backend}"
+ )
+
+ try:
+ if group_name in self._model_update_group:
+ raise ValueError(f"Process group with name {group_name} already exists.")
+
+ self._model_update_group[group_name] = init_custom_process_group(
+ backend=request.backend,
+ init_method=f"tcp://{request.master_address}:{request.master_port}",
+ world_size=world_size,
+ rank=rank,
+ group_name=group_name,
+ )
+ return True, "Succeeded to initialize custom process group."
+
+ except Exception as e:
+ message = f"Failed to initialize custom process group: {e}."
+ self.logger.error(message)
+ return False, message
+
+ def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq):
+ try:
+ if request.group_name in self._model_update_group:
+ pg = self._model_update_group.pop(request.group_name)
+ torch.distributed.destroy_process_group(pg)
+ return True, "Succeeded to destroy custom process group."
+ else:
+ return False, "The group to be destroyed does not exist."
+ except Exception as e:
+ message = f"Failed to destroy custom process group: {e}."
+ self.logger.error(message)
+ return False, message
+
+ def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq):
+ """
+        Update specific parameters of the model weights online through the
+        `_model_update_group` process group.
+
+        Args:
+            request: carries the names, dtypes and shapes of the parameters to be
+                updated; rank 0 of the request's process group broadcasts each
+                tensor in the same order.
+ """
+
+ assert request.group_name in self._model_update_group, (
+ f"Group {request.group_name} not in {list(self._model_update_group.keys())}. "
+ "Please call `init_weights_update_group` first."
+ )
+
+ try:
+ weights = []
+ handles = []
+ for name, dtype, shape in zip(request.names, request.dtypes, request.shapes):
+ target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+ weight = torch.empty(shape, dtype=target_dtype, device="cuda")
+ handles.append(
+ torch.distributed.broadcast(
+ weight,
+ src=0,
+ group=self._model_update_group[request.group_name],
+ async_op=True,
+ )
+ )
+ weights.append((name, weight))
+ for handle in handles:
+ handle.wait()
+
+ self.model.load_weights(weights)
+ return True, "Succeeded to update parameter online from distributed."
+
+ except Exception as e:
+ error_msg = (
+ f"Failed to update parameter online: {e}. "
+ f"The full weights of the ModelRunner are partially updated. "
+ f"Please discard the whole weights."
+ )
+ self.logger.error(error_msg)
+ return False, error_msg
+
+ def _update_weights_from_flattened_bucket(
+ self,
+ flattened_tensor_bucket_dict,
+ ):
+ """Handle flattened bucket format for weight updates"""
+ flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+ metadata = flattened_tensor_bucket_dict["metadata"]
+
+ # Convert metadata dict to our format
+ converted_metadata = []
+ for meta in metadata:
+ converted_meta = FlattenedTensorMetadata(
+ name=meta.name,
+ shape=meta.shape,
+ dtype=meta.dtype,
+ start_idx=meta.start_idx,
+ end_idx=meta.end_idx,
+ numel=meta.numel,
+ )
+ converted_metadata.append(converted_meta)
+
+ # Create bucket and reconstruct tensors
+ bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata)
+ reconstructed_tensors = bucket.reconstruct_tensors()
+
+ named_tensors = {name: tensor for name, tensor in reconstructed_tensors}
+
+ # Load the reconstructed tensors using the standard method
+ self.model.load_weights(named_tensors)
+
+ return True, "Succeeded to update parameter online from flattened bucket tensor."
+
+ def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq):
+ try:
+ monkey_patch_torch_reductions()
+ if request.load_format == "flattened_bucket":
+ # Handle flattened bucket format
+ serialized_named_tensors = MultiprocessingSerializer.deserialize(
+ request.serialized_named_tensors[self.rank_in_dp]
+ )
+ return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=serialized_named_tensors)
+
+            # The device must be queried after the patch, otherwise the reported device would be wrong
+            self.device_module = torch.get_device_module("cuda")
+            inferred_device = self.device_module.current_device()
+
+ named_tensors = MultiprocessingSerializer.deserialize(request.serialized_named_tensors[self.rank_in_dp])
+
+ def _unwrap_tensor(tensor, tp_rank, device):
+ if isinstance(tensor, LocalSerializedTensor):
+ tensor = tensor.get(tp_rank)
+ clone = tensor.to(device).clone()
+ del tensor # free the ipc tensor
+ return clone
+
+ named_tensors = {
+                name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=inferred_device)
+ for name, tensor in named_tensors
+ }
+
+ self.model.load_weights(named_tensors)
+
+ return True, "Succeeded to update parameter online from tensor."
+
+ except Exception as e:
+ message = f"Failed to update parameter online from tensor. Reason: {e}."
+ self.logger.error(message)
+
+ return False, message
+
def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
"""
这个函数会把next token id和logprobs保存到pinned memory中
@@ -798,6 +997,18 @@ def _sample_and_scatter_token(
)
return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu
+ def _flush_routing_to_kv_buffer(self, mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
+ """Scatter captured routing data from capture buffer to KV-indexed GPU buffer.
+
+ Must be called AFTER model.forward() completes. mem_indexes should be the
+ original (unpadded) tensor — either CPU or CUDA.
+ """
+ if _routing_mgr.g_routing_capture_manager is not None and mem_indexes is not None:
+ if not mem_indexes.is_cuda:
+ mem_indexes = mem_indexes.cuda(non_blocking=True)
+ num_tokens = mem_indexes.shape[0]
+ _routing_mgr.g_routing_capture_manager.flush_to_routing_buffer(mem_indexes, num_tokens, microbatch_index)
+
def _dp_all_gather_prefill_and_decode_req_num(
self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]
) -> Tuple[np.ndarray, np.ndarray]:
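
`update_weights_from_distributed` posts one async broadcast per announced tensor and expects rank 0 of the custom group to send them in the same order. A hedged trainer-side counterpart under that assumption; the helper below is illustrative and not part of this PR:

```python
import torch.distributed as dist

def broadcast_updated_weights(group, named_params):
    """Trainer-side mirror of update_weights_from_distributed: rank 0 broadcasts
    each tensor in exactly the order announced in the request's `names` field."""
    handles = []
    for _name, param in named_params:
        # The inference side allocates an empty CUDA tensor of the announced
        # shape/dtype and posts a matching async broadcast with src=0.
        handles.append(dist.broadcast(param.data, src=0, group=group, async_op=True))
    for handle in handles:
        handle.wait()

# Usage sketch: broadcast_updated_weights(group, list(model.named_parameters()))
```
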
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index a8a5224ebc..9f4443e48e 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -109,6 +109,7 @@ def prefill_normal(
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits,
b_req_idx=model_input.b_req_idx,
@@ -148,6 +149,7 @@ def decode_normal(
model_input, run_reqs = prepare_decode_inputs(decode_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits,
b_req_idx=model_input.b_req_idx,
@@ -186,6 +188,7 @@ def prefill_mtp(
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits,
b_req_idx=model_input.b_req_idx,
@@ -236,6 +239,7 @@ def decode_mtp(
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
b_mtp_index_cpu = model_input.b_mtp_index
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id)
# verify the next_token_ids
b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0]
diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py
index 5a179cb620..ebc55b7ef4 100644
--- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py
@@ -40,8 +40,8 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq
)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
logits = model_output.logits
batch_idx, run_reqs = self._diverse_copy(
diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
index bb0e848e76..f01e5fe935 100644
--- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
@@ -145,6 +145,7 @@ def prefill_normal(
run_reqs_num = len(run_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
if run_reqs_num > 0:
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits[:run_reqs_num],
@@ -188,6 +189,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq
run_reqs_num = len(run_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
if run_reqs_num > 0:
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits[:run_reqs_num],
@@ -236,6 +238,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1)
+ self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0)
+ self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1)
logits0 = model_output0.logits
logits1 = model_output1.logits
@@ -305,6 +309,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1)
+ self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0)
+ self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1)
logits0 = model_output0.logits
logits1 = model_output1.logits
@@ -359,6 +365,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
req_num = len(run_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output: ModelOutput = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num]
logits = model_output.logits[0:req_num, :]
b_req_idx = model_input.b_req_idx[0:req_num]
@@ -421,6 +428,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None
if req_num > 0:
logits = model_output.logits[0:req_num, :]
@@ -629,6 +637,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
) = padded_overlap_prepare_prefill_inputs(prefill_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1)
+ self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0)
+ self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1)
logits0 = model_output0.logits
logits1 = model_output1.logits
req_num0, req_num1 = len(run_reqs0), len(run_reqs1)
@@ -726,8 +736,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
b_mtp_index_cpu0 = model_input0.b_mtp_index
b_mtp_index_cpu1 = model_input1.b_mtp_index
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-
model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1)
+ self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0)
+ self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1)
logits0 = model_output0.logits
logits1 = model_output1.logits
run_reqs = run_reqs0 + run_reqs1
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
index e2ccf290e8..fc551b08ea 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
@@ -1,7 +1,8 @@
import torch
from typing import List
-from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
-from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache
+from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty
+from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache
+from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids
from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
from lightllm.utils.envs_utils import get_env_start_args
@@ -14,7 +15,10 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
b_top_ks,
b_length_penalty_param,
b_mask_eos_reqs,
+ invalid_token_ids,
+ cu_invalid_token_num,
is_all_greedy,
+ has_invalid_token_ids,
) = _get_post_sample_tensors(reqs)
eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True)
@@ -59,6 +63,14 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
eos_ids=eos_ids,
sampling_params_manager=sampling_params_manager,
)
+
+ if has_invalid_token_ids:
+ apply_invalid_token_ids(
+ Logits=logits,
+ invalid_token_ids=invalid_token_ids,
+ cu_invalid_token_num=cu_invalid_token_num,
+ )
+
logits.div_(b_temperatures.view((-1, 1)))
probs = torch.softmax(logits, dim=-1)
@@ -112,6 +124,12 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
mask_eos_reqs: List[bool] = []
is_all_greedy = True
+ # invalid token ids
+ invalid_token_ids: List[int] = []
+ has_invalid_token_ids = False
+ cu_invalid_token_num = [0]
+ invalid_token_num_start = 0
+
for i, req_obj in enumerate(reqs):
sample_param = req_obj.sampling_param
shm_param = sample_param.shm_param
@@ -127,6 +145,11 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
if top_k_val > 1:
is_all_greedy = False
req_idxes.append(req_obj.req_idx)
+ invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids)
+ cu_invalid_token_num.append(invalid_token_num_start)
+ if len(req_obj.sampling_param.invalid_token_ids) > 0:
+ has_invalid_token_ids = True
+ invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids)
req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True)
temperatures_cpu = torch.tensor(temperatures, dtype=torch.float, device="cpu", pin_memory=True)
@@ -135,6 +158,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
length_penalty_param_cpu = torch.tensor(length_penalty_param, dtype=torch.int32, device="cpu", pin_memory=True)
mask_eos_reqs_cpu = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cpu", pin_memory=True)
+ if has_invalid_token_ids:
+ invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True)
+ cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True)
+
return (
req_idxes_cpu.cuda(non_blocking=True),
temperatures_cpu.cuda(non_blocking=True),
@@ -142,5 +169,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
top_ks_cpu.cuda(non_blocking=True),
length_penalty_param_cpu.cuda(non_blocking=True),
mask_eos_reqs_cpu.cuda(non_blocking=True),
+ invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None,
+ cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None,
is_all_greedy,
+ has_invalid_token_ids,
)
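
The flattened `invalid_token_ids` list plus the `cu_invalid_token_num` prefix sums let the kernel locate each request's banned tokens. A plain-PyTorch reference of the intended behavior, assuming the kernel masks banned logits to -inf before softmax (handy for checking `apply_invalid_token_ids`):

```python
import torch

def apply_invalid_token_ids_reference(logits, invalid_token_ids, cu_invalid_token_num):
    # For request i, ban every token id in
    # invalid_token_ids[cu_invalid_token_num[i] : cu_invalid_token_num[i + 1]].
    for i in range(logits.shape[0]):
        start, end = int(cu_invalid_token_num[i]), int(cu_invalid_token_num[i + 1])
        if end > start:
            logits[i, invalid_token_ids[start:end].long()] = float("-inf")
    return logits

# Two requests over a toy 6-token vocab: request 0 bans tokens {1, 4}, request 1 bans nothing.
logits = torch.zeros(2, 6)
invalid_token_ids = torch.tensor([1, 4], dtype=torch.int32)
cu_invalid_token_num = torch.tensor([0, 2, 2], dtype=torch.int32)
apply_invalid_token_ids_reference(logits, invalid_token_ids, cu_invalid_token_num)
```
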
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 55fe7a415e..d7b85adf68 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -33,6 +33,8 @@
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
+from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp
logger = init_logger(__name__)
@@ -179,6 +181,34 @@ def init_model(self, kvargs):
def get_max_total_token_num(self):
return self.backend.get_max_total_token_num()
+ def release_memory_occupation(self, tags: List[MemoryTag]):
+ try:
+ self.backend.release_memory_occupation(tags)
+ return True
+ except BaseException as e:
+ logger.exception(f"release memory occupation failed: {str(e)}")
+ return False
+
+ def resume_memory_occupation(self, tags: List[MemoryTag]):
+ try:
+ self.backend.resume_memory_occupation(tags)
+ return True
+ except BaseException as e:
+ logger.exception(f"resume memory occupation failed: {str(e)}")
+ return False
+
+ def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+ try:
+ if self.backend is None or not hasattr(self.backend, req.func_name):
+ raise ValueError(f"Backend does not support function {req.func_name}")
+ success, ret = getattr(self.backend, req.func_name)(req.func_args)
+ return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret)
+ except BaseException as e:
+ logger.exception(f"forward to model backend failed: {str(e)}")
+ return GeneralModelToHttpRpcRsp(
+ success=False, msg=f"forward to model backend failed: {str(e)}", func_name=req.func_name
+ )
+
class ModelRpcClient:
def __init__(self, rpc_event, rpc_finished_event):
@@ -209,6 +239,16 @@ async def get_max_total_token_num(self):
assert func_name == "get_max_total_token_num"
return ret
+ def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+ self.rpc_shm_params.write_func_params("forward_to_model", (req,))
+ self.rpc_event.set()
+
+ self.rpc_finished_event.wait()
+ self.rpc_finished_event.clear()
+ func_name, ret = self.rpc_shm_results.read_func_result()
+ assert func_name == "forward_to_model"
+ return ret
+
def _init_env(
args,
@@ -269,7 +309,11 @@ async def start_model_process(
success_event,
),
)
- proc.start()
+ from lightllm.utils.torch_memory_saver_utils import TorchMemorySaverWrapper
+
+ torch_memory_saver = TorchMemorySaverWrapper(args.enable_torch_memory_saver)
+ with torch_memory_saver.configure_subprocess():
+ proc.start()
# Use asyncio.to_thread to make the blocking wait non-blocking
await asyncio.to_thread(success_event.wait, timeout=40)
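
`forward_to_model` resolves `func_name` with `getattr` on the backend and expects the method to return a `(success, payload)` pair, as the backend methods added above do. A hedged sketch of wrapping a cache flush into the generic RPC (assumes `BaseReq` adds no required constructor fields):

```python
from lightllm.server.io_struct import GeneralHttpToModelRpcReq, FlushCacheReq

# func_name must match a backend method such as ModeBackend.flush_cache, which
# returns (success, message); the router wraps that pair into GeneralModelToHttpRpcRsp.
rpc_req = GeneralHttpToModelRpcReq(func_name="flush_cache", func_args=FlushCacheReq())

# On the router side (see the router manager's forward_to_model above):
# rsp = self.model_rpc_client.forward_to_model(rpc_req)
# rsp.success, rsp.func_rsp  -> (True, "Succeeded to flush cache.")
```
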
diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py
index 36aefae6e7..d7ef06828b 100644
--- a/lightllm/server/router/req_queue/base_queue.py
+++ b/lightllm/server/router/req_queue/base_queue.py
@@ -34,6 +34,17 @@ def free_aborted_req_cpu_cache_pages(self, req: Req):
req.cpu_cache_match_page_indexes.clear()
self.router.cpu_cache_client.lock.release()
+ def free_aborted_req(self, req: Req):
+        # Finish never-scheduled requests directly with an empty string so the http server can still respond
+ input_len = req.input_len
+ req.link_prompt_ids_shm_array()
+ req.link_logprobs_shm_array()
+ req.candetoken_out_len = 1
+ req.finish_token_index = input_len
+ req.shm_prompt_ids.arr[input_len] = self.args.eos_id[0]
+ req.shm_logprobs.arr[input_len] = 0
+ req.finish_status.set_status(FinishStatus.FINISHED_ABORTED)
+
def extend(self, req_group: List[Req]):
for req in req_group:
req.sample_params.suggested_dp_index = self.dp_index
diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
index ae7c90b335..ed2a5dbb12 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
@@ -89,7 +89,7 @@ def generate_new_batch(self, current_batch: Batch):
aborted_count = 0
cur_group_reqs = []
for req in self.waiting_req_list:
- if req.is_aborted:
+ if req.is_aborted and not self.router.is_multinode_tp:
aborted_count += 1
abort_req_list.append(req)
continue
@@ -111,7 +111,7 @@ def generate_new_batch(self, current_batch: Batch):
ok_insert, new_batch_first_router_need_tokens = self._can_add_new_group_reqs(
cur_group_reqs, is_busy, new_batch_first_router_need_tokens
)
- if ok_insert:
+            if ok_insert:
can_run_list.extend(cur_group_reqs)
new_batch = None
@@ -120,6 +120,7 @@ def generate_new_batch(self, current_batch: Batch):
for req in abort_req_list:
self.free_aborted_req_cpu_cache_pages(req)
+ self.free_aborted_req(req)
self.router.shm_req_manager.put_back_req_obj(req)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
index 0d870b55d8..9449798e9c 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -79,8 +79,8 @@ def generate_new_batch(self, current_batch: Batch):
waiting_queue = self.waiting_req_list
for req in waiting_queue:
- if req.is_aborted:
- # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉.
+ if req.is_aborted and not self.router.is_multinode_tp:
+                # Due to management complexity, only never-scheduled single-node requests can be dropped on abort here.
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏
aborted_count += 1
abort_req_list.append(req)
@@ -97,6 +97,7 @@ def generate_new_batch(self, current_batch: Batch):
new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
for req in abort_req_list:
self.free_aborted_req_cpu_cache_pages(req)
+ self.free_aborted_req(req)
self.router.shm_req_manager.put_back_req_obj(req)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py
index f2658159b4..842b93648b 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py
@@ -70,7 +70,7 @@ def generate_new_batch(self, current_batch: Batch):
waiting_queue = self.waiting_req_list
for req in waiting_queue:
- if req.is_aborted:
+ if req.is_aborted and not self.router.is_multinode_tp:
# 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉.
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏
aborted_count += 1
@@ -88,6 +88,7 @@ def generate_new_batch(self, current_batch: Batch):
new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
for req in abort_req_list:
self.free_aborted_req_cpu_cache_pages(req)
+ self.free_aborted_req(req)
self.router.shm_req_manager.put_back_req_obj(req)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
index e0da134875..3dea3cf955 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
@@ -38,7 +38,7 @@ def generate_new_batch(self, current_batch: Batch):
abort_req_list = []
aborted_count = 0
for req in self.waiting_req_list:
- if req.is_aborted:
+ if req.is_aborted and not self.router.is_multinode_tp:
# 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉.
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏
aborted_count += 1
@@ -53,6 +53,7 @@ def generate_new_batch(self, current_batch: Batch):
new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
for req in abort_req_list:
self.free_aborted_req_cpu_cache_pages(req)
+ self.free_aborted_req(req)
self.router.shm_req_manager.put_back_req_obj(req)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py
index a73823b8b7..e5f731df5f 100644
--- a/lightllm/server/router/req_queue/dp_base_queue.py
+++ b/lightllm/server/router/req_queue/dp_base_queue.py
@@ -27,6 +27,12 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
self.reqs_waiting_for_dp_index: List[List[Req]] = []
return
+ def free_aborted_req(self, req: Req):
+ dp_index = req.sample_params.suggested_dp_index
+ assert dp_index >= 0 and dp_index < self.dp_size_in_node
+ self.inner_queues[dp_index].free_aborted_req(req)
+ return
+
def get_dp_queue(self, dp_index: int):
assert dp_index < self.dp_size_in_node, "dp index out of range"
return self.inner_queues[dp_index]
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index e0b2bd425e..3563739f79 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -30,6 +30,7 @@
from ..models.qwen2_vl.model import QWen2VLTokenizer
from ..models.qwen3_vl.model import QWen3VLTokenizer
from ..models.internvl.model import InternvlTokenizer
+from ..models.neo_chat_moe.model import NeoChatTokenizer
from ..models.gemma3.model import Gemma3Tokenizer
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
@@ -104,5 +105,7 @@ def get_tokenizer(
tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
elif model_type == "gemma3":
tokenizer = Gemma3Tokenizer(tokenizer, model_cfg)
+ elif model_type == "neo_chat":
+ tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
return tokenizer
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index 202c2fc453..a54a4aeffa 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -8,7 +8,6 @@
import inspect
import setproctitle
from typing import List
-from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
from lightllm.server.core.objs import ShmReqManager, StartArgs
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -19,6 +18,7 @@
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.io_struct import BaseReq, GenerateReqIndex
from rpyc.utils.classic import obtain
@@ -49,7 +49,7 @@ def __init__(
self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True})
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.cache_port = args.cache_port
- self.waiting_reqs: List[GroupReqIndexes] = []
+ self.waiting_reqs: List[BaseReq] = []
self.model_weightdir = args.model_dir
self.tp_world_size = args.tp
self.vit_dp = args.visual_dp
@@ -187,11 +187,12 @@ async def loop_for_netio_req(self):
while True:
try:
for _ in range(self.visual_recv_max_count):
- recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
- if isinstance(recv_req, GroupReqIndexes):
+ recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+                    # Currently only GenerateReqIndex enters this queue; check whether image inference is needed
+ if isinstance(recv_req, GenerateReqIndex):
self.waiting_reqs.append(recv_req)
else:
- assert False, f"Error Req Inf {recv_req}"
+ self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL)
self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
except zmq.ZMQError:
# 当队列已经开始清空的时候,将一次接受数量下调
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index 8f4a1ee450..22dfa915ba 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -4,6 +4,7 @@
import torch
import socket
import inspect
+import setproctitle
from datetime import timedelta
from typing import Dict, List, Tuple
from transformers.configuration_utils import PretrainedConfig
@@ -19,12 +20,15 @@
from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel
from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel
from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
+from lightllm.models.neo_chat_moe.neo_visual import NeoVisionTransformerPretrainedModel
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.dist_utils import init_vision_distributed_env
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
from lightllm.server.visualserver import set_vit_att_backend
+from lightllm.utils.process_check import start_parent_check_thread
+from lightllm.utils.envs_utils import get_unique_server_name
class VisualModelRpcServer(rpyc.Service):
@@ -80,6 +84,8 @@ def exposed_init_model(self, kvargs):
# self.model = InternVLVisionModel()
elif self.model_type == "gemma3":
self.model = Gemma3VisionModel()
+ elif self.model_type == "neo_chat":
+ self.model = NeoVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
else:
raise Exception(f"can not support {self.model_type} now")
@@ -172,6 +178,8 @@ async def encode(self, images: List[ImageItem]):
def _init_env(port, device_id):
# 注册graceful 退出的处理
graceful_registry(inspect.currentframe().f_code.co_name)
+ setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server::RANK{device_id}")
+ start_parent_check_thread()
import lightllm.utils.rpyc_fix_utils as _
diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py
index a1ed6ed950..9e46a57ec3 100644
--- a/lightllm/utils/device_utils.py
+++ b/lightllm/utils/device_utils.py
@@ -100,7 +100,8 @@ def get_current_device_name():
gpu_name = gpu_name.replace(" ", "_")
return gpu_name
else:
- return None
+        return "unknown"  # TODO: callers should handle the no-GPU case so this can raise instead
+        # raise RuntimeError("No GPU available")
@lru_cache(maxsize=None)
diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index 65ac401d4c..28667c6d00 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -65,12 +65,15 @@ def init_vision_distributed_env(kvargs):
device_id = visual_gpu_ids[kvargs["vit_rank_id"]]
set_current_device_id(device_id)
torch.cuda.set_device(device_id)
+    # Do not pass device_id explicitly to init_process_group here. Doing so triggers torch's
+    # device-bound split optimization, which assumes any rank that later joins a new process group
+    # already belongs to the default group. RL weight-update groups joined from outside would then
+    # hang forever on collectives such as all_reduce, waiting for replies from LightLLM's default group.
dist.init_process_group(
"nccl",
init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
rank=kvargs["tp_rank_id"],
world_size=tp_world_size,
- device_id=torch.device(f"cuda:{device_id}"),
)
# warmup nccl communicator
_a = torch.zeros([1]).to(f"cuda:{device_id}")
@@ -104,7 +107,6 @@ def init_distributed_env(kvargs):
init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}',
rank=kvargs["rank_id"],
world_size=kvargs["world_size"],
- device_id=torch.device(f"cuda:{device_id}"),
)
# warmup nccl communicator
_a = torch.zeros([1]).to(f"cuda:{device_id}")
@@ -270,3 +272,71 @@ def _init_nccl_env():
assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}"
return
+
+
+# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py#L1675
+def init_custom_process_group(
+ backend=None,
+ init_method=None,
+ timeout=None,
+ world_size=-1,
+ rank=-1,
+ store=None,
+ group_name=None,
+ pg_options=None,
+ device_id=None,
+):
+ from torch.distributed.distributed_c10d import (
+ Backend,
+ PrefixStore,
+ _new_process_group_helper,
+ _world,
+ default_pg_timeout,
+ rendezvous,
+ )
+
+ assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
+
+ if store is not None:
+ assert world_size > 0, "world_size must be positive if using store"
+ assert rank >= 0, "rank must be non-negative if using store"
+ elif init_method is None:
+ init_method = "env://"
+
+ if backend:
+ backend = Backend(backend)
+ else:
+ backend = Backend("undefined")
+
+ if timeout is None:
+ timeout = default_pg_timeout
+
+ # backward compatible API
+ if store is None:
+ rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+ store, rank, world_size = next(rendezvous_iterator)
+ store.set_timeout(timeout)
+
+ # Use a PrefixStore to avoid accidental overrides of keys used by
+ # different systems (e.g. RPC) in case the store is multi-tenant.
+ store = PrefixStore(group_name, store)
+
+    # NOTE: The pg_options parameter was renamed to backend_options in PyTorch 2.6.0
+    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
+    torch_major_minor = tuple(int(x) for x in str(torch.__version__).split(".")[:2])
+    pg_options_param_name = "backend_options" if torch_major_minor >= (2, 6) else "pg_options"
+ pg, _ = _new_process_group_helper(
+ world_size,
+ rank,
+ [],
+ backend,
+ store,
+ group_name=group_name,
+ **{pg_options_param_name: pg_options},
+ timeout=timeout,
+ device_id=device_id,
+ )
+
+ _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+ return pg
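
A hedged usage sketch of `init_custom_process_group`: the trainer takes rank 0 and every inference worker joins at `rank_offset + rank_in_dp`, matching how `init_weights_update_group` computes its rank; the address, port, and world size below are illustrative.

```python
from lightllm.utils.dist_utils import init_custom_process_group

TRAINER_ADDR = "10.0.0.1"  # illustrative host running the trainer's rank 0
TRAINER_PORT = 29600
TP_WORLD_SIZE = 8

# Trainer side: rank 0, so its dist.broadcast calls reach every inference rank.
group = init_custom_process_group(
    backend="nccl",
    init_method=f"tcp://{TRAINER_ADDR}:{TRAINER_PORT}",
    world_size=1 + TP_WORLD_SIZE,
    rank=0,
    group_name="weight_update_group",
)
# Inference side joins with rank=rank_offset + rank_in_dp and the same world_size and group_name.
```
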
diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py
index 3a0e28bcb6..a702a465b2 100644
--- a/lightllm/utils/envs_utils.py
+++ b/lightllm/utils/envs_utils.py
@@ -13,6 +13,7 @@ def set_unique_server_name(args):
if args.run_mode == "pd_master":
os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master"
else:
+        assert args.nccl_port is not None, "nccl_port is not set"
os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank)
return
@@ -20,6 +21,8 @@ def set_unique_server_name(args):
@lru_cache(maxsize=None)
def get_unique_server_name():
service_uni_name = os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID")
+    assert service_uni_name is not None and "None" not in service_uni_name, "service_uni_name is not set"
+
return service_uni_name
diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py
index 3256fdd1fd..f44aad92ac 100644
--- a/lightllm/utils/kv_cache_utils.py
+++ b/lightllm/utils/kv_cache_utils.py
@@ -24,6 +24,7 @@
PPLINT8KVMemoryManager,
PPLINT4KVMemoryManager,
Deepseek2MemoryManager,
+ NeoMemoryManager,
)
from typing import List, Tuple, Optional
@@ -97,6 +98,28 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta":
scale_head_dim=get_head_dim(args.model_dir) // 8,
scale_data_type=get_llm_data_type(),
)
+ elif mem_manager_class is PPLINT8KVMemoryManager:
+ cpu_cache_meta = CpuKVCacheMeta(
+ page_num=0,
+ token_page_size=args.cpu_cache_token_page_size,
+ layer_num=get_layer_num(args.model_dir),
+ num_heads=get_num_key_value_heads(args.model_dir) * 2,
+ head_dim=get_head_dim(args.model_dir) * 2,
+ data_type=get_llm_data_type(),
+ scale_head_dim=0,
+ scale_data_type=get_llm_data_type(),
+ )
+ elif mem_manager_class is MemoryManager:
+ cpu_cache_meta = CpuKVCacheMeta(
+ page_num=0,
+ token_page_size=args.cpu_cache_token_page_size,
+ layer_num=get_layer_num(args.model_dir),
+ num_heads=get_num_key_value_heads(args.model_dir) * 2,
+ head_dim=get_head_dim(args.model_dir),
+ data_type=get_llm_data_type(),
+ scale_head_dim=0,
+ scale_data_type=get_llm_data_type(),
+ )
else:
logger.error(f"not support mem manager: {mem_manager_class} for cpu kv cache")
raise Exception(f"not support mem manager: {mem_manager_class} for cpu kv cache")
diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py
index 20b9888753..486414e88e 100644
--- a/lightllm/utils/net_utils.py
+++ b/lightllm/utils/net_utils.py
@@ -2,44 +2,72 @@
import subprocess
import ipaddress
import random
+import portpicker
from lightllm.utils.log_utils import init_logger
logger = init_logger(__name__)
def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000):
+ if used_nccl_ports is None:
+ used_nccl_ports = []
+
port_list = []
- for port in range(from_port_num, 65536):
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- result = s.connect_ex(("localhost", port))
- if result != 0 and port not in used_nccl_ports:
+ max_attempts = num * 50 # Allow more attempts to find ports in range
+
+ for _ in range(max_attempts):
+ if len(port_list) >= num:
+ break
+
+ try:
+ port = portpicker.pick_unused_port()
+
+ if port >= from_port_num and port not in used_nccl_ports:
port_list.append(port)
- if len(port_list) > num * 30:
- break
+ logger.debug(f"Allocated port: {port}")
+ else:
+ logger.debug(f"Port {port} is out of range or in used_nccl_ports, skipping")
+
+ except Exception as e:
+ logger.warning(f"Failed to allocate port: {e}")
+ continue
if len(port_list) < num:
+ logger.error(f"Failed to allocate {num} ports, only got {len(port_list)}")
return None
- random.shuffle(port_list)
- return port_list[0:num]
+ logger.info(f"Successfully allocated {len(port_list)} ports: {port_list}")
+ return port_list
def alloc_can_use_port(min_port, max_port):
port_list = []
for port in range(min_port, max_port):
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- result = s.connect_ex(("localhost", port))
+ try:
+ test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ result = test_socket.connect_ex(("localhost", port))
+ test_socket.close()
+
if result != 0:
port_list.append(port)
+ except Exception:
+ continue
return port_list
def find_available_port(start_port, end_port):
for port in range(start_port, end_port + 1):
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
- result = sock.connect_ex(("localhost", port))
+ try:
+ test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ result = test_socket.connect_ex(("localhost", port))
+ test_socket.close()
+
if result != 0:
return port
+ except Exception:
+ continue
return None
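
A usage sketch for the reworked allocator (assumes portpicker is installed, per the requirements.txt change below); the unpacked names are illustrative:

    from lightllm.utils.net_utils import alloc_can_use_network_port

    ports = alloc_can_use_network_port(num=3, used_nccl_ports=[28765], from_port_num=10000)
    if ports is None:
        raise RuntimeError("could not allocate 3 free ports >= 10000")
    nccl_port, http_port, metric_port = ports  # illustrative names
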
diff --git a/lightllm/utils/patch_torch.py b/lightllm/utils/patch_torch.py
new file mode 100644
index 0000000000..9f51edeb64
--- /dev/null
+++ b/lightllm/utils/patch_torch.py
@@ -0,0 +1,63 @@
+# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/patch_torch.py
+from typing import Callable, Union
+
+import torch
+from packaging import version
+from torch.multiprocessing import reductions
+
+
+def monkey_patch_torch_reductions():
+ """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed"""
+
+ # Currently, NPU does not support UUID. This has been temporarily commented out,
+ # with support expected in the fourth quarter.
+ # if _is_npu:
+ # return
+
+ if hasattr(reductions, "_reduce_tensor_original"):
+ return
+
+ reductions._reduce_tensor_original = reductions.reduce_tensor
+ reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor
+
+ reductions.reduce_tensor = _reduce_tensor_modified
+ reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified
+
+ reductions.init_reductions()
+
+
+# The signature has not been changed for years, and we will not need this when the next version is released,
+# so it looks safe to use a constant.
+_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6
+
+
+def _reduce_tensor_modified(*args, **kwargs):
+ output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs)
+ output_args = _modify_tuple(output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid)
+ return output_fn, output_args
+
+
+def _rebuild_cuda_tensor_modified(*args):
+ args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid)
+ return reductions._rebuild_cuda_tensor_original(*args)
+
+
+def _device_to_uuid(device: int) -> str:
+ return str(torch.cuda.get_device_properties(device).uuid)
+
+
+def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
+ if isinstance(device_maybe_uuid, int):
+ return device_maybe_uuid
+
+ if isinstance(device_maybe_uuid, str):
+ for device in range(torch.cuda.device_count()):
+ if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid:
+ return device
+ raise Exception("Invalid device_uuid=" + device_maybe_uuid)
+
+ raise Exception(f"Unknown type: {device_maybe_uuid=}")
+
+
+def _modify_tuple(t, index: int, modifier: Callable):
+ return *t[:index], modifier(t[index]), *t[index + 1 :]
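
A minimal illustration of what the patch changes, assuming a CUDA device is present: after patching, the reduction args carry the device UUID string instead of the ordinal index, so a receiving process with a different CUDA_VISIBLE_DEVICES mapping can still resolve the right GPU.

    import torch
    from torch.multiprocessing import reductions
    from lightllm.utils.patch_torch import monkey_patch_torch_reductions

    monkey_patch_torch_reductions()
    t = torch.zeros(4, device="cuda:0")
    _, args = reductions.reduce_tensor(t)
    print(args[6])  # a UUID string such as "GPU-3f2c..." rather than the ordinal 0
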
diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py
new file mode 100644
index 0000000000..d8180aeb0c
--- /dev/null
+++ b/lightllm/utils/serializer.py
@@ -0,0 +1,131 @@
+# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py
+
+import base64
+import pickle
+import io
+from dataclasses import dataclass
+from multiprocessing.reduction import ForkingPickler
+from typing import List
+
+
+class MultiprocessingSerializer:
+ @staticmethod
+ def serialize(obj, output_str: bool = False):
+ """
+ Serialize a Python object using ForkingPickler.
+
+ Args:
+ obj: The object to serialize.
+ output_str (bool): If True, return a base64-encoded string instead of raw bytes.
+
+ Returns:
+ bytes or str: The serialized object.
+ """
+ buf = io.BytesIO()
+ ForkingPickler(buf).dump(obj)
+ buf.seek(0)
+ output = buf.read()
+
+ if output_str:
+ # Convert bytes to base64-encoded string
+ output = base64.b64encode(output).decode("utf-8")
+
+ return output
+
+ @staticmethod
+ def deserialize(data):
+ """
+ Deserialize a previously serialized object.
+
+ Args:
+ data (bytes or str): The serialized data, optionally base64-encoded.
+
+ Returns:
+ The deserialized Python object.
+ """
+ if isinstance(data, str):
+ # Decode base64 string to bytes
+ data = base64.b64decode(data, validate=True)
+
+ return SafeUnpickler(io.BytesIO(data)).load()
+
+
+class SafeUnpickler(pickle.Unpickler):
+ ALLOWED_MODULE_PREFIXES = {
+ # --- Python types ---
+ "builtins.",
+ "collections.",
+ "copyreg.",
+ "functools.",
+ "itertools.",
+ "operator.",
+ "types.",
+ "weakref.",
+ # --- PyTorch types ---
+ "torch.",
+ "torch._tensor.",
+ "torch.storage.",
+ "torch.nn.parameter.",
+ "torch.autograd.function.",
+ # --- torch distributed ---
+ "torch.distributed.",
+ "torch.distributed._shard.",
+ "torch.distributed._composable.",
+ "torch._C._distributed_c10d.",
+ "torch._C._distributed_fsdp.",
+ "torch.distributed.optim.",
+ # --- multiprocessing ---
+ "multiprocessing.resource_sharer.",
+ "multiprocessing.reduction.",
+ "pickletools.",
+ # --- PEFT / LoRA ---
+ "peft.",
+ "transformers.",
+ "huggingface_hub.",
+ # --- SGLang & Unitest ---
+ "sglang.srt.weight_sync.tensor_bucket.",
+ "sglang.srt.model_executor.model_runner.",
+ "sglang.srt.layers.",
+ "sglang.srt.utils.",
+ # --- LightLLM ---
+ "lightllm.utils.",
+ }
+
+ DENY_CLASSES = {
+ ("builtins", "eval"),
+ ("builtins", "exec"),
+ ("builtins", "compile"),
+ ("os", "system"),
+ ("subprocess", "Popen"),
+ ("subprocess", "run"),
+ ("codecs", "decode"),
+ ("types", "CodeType"),
+ ("types", "FunctionType"),
+ }
+
+ def find_class(self, module, name):
+ # Block deterministic attacks
+ if (module, name) in self.DENY_CLASSES:
+ raise RuntimeError(
+ f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164"
+ )
+ # Allowlist of safe-to-load modules.
+ if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES):
+ return super().find_class(module, name)
+
+ # Block everything else. (Potential attack surface)
+ raise RuntimeError(
+ f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164"
+ )
+
+
+@dataclass
+class LocalSerializedTensor:
+ """torch.Tensor that gets serialized by MultiprocessingSerializer
+ (which only serializes a pointer and not the data).
+ The i-th element in the list corresponds to i-th rank's GPU."""
+
+ values: List[bytes]
+
+ def get(self, rank: int):
+ return MultiprocessingSerializer.deserialize(self.values[rank])
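
A round-trip sketch of the serializer: plain Python objects go through as-is, while tensor payloads (the LocalSerializedTensor case) are meant to be deserialized in a different process, since ForkingPickler shares handles rather than copying data. SafeUnpickler only resolves classes under the allowlisted prefixes, so anything outside them raises at load time.

    from lightllm.utils.serializer import MultiprocessingSerializer

    payload = MultiprocessingSerializer.serialize({"step": 1, "ok": True}, output_str=True)
    assert MultiprocessingSerializer.deserialize(payload) == {"step": 1, "ok": True}
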
diff --git a/lightllm/utils/tensor_bucket.py b/lightllm/utils/tensor_bucket.py
new file mode 100644
index 0000000000..a9d7a367dd
--- /dev/null
+++ b/lightllm/utils/tensor_bucket.py
@@ -0,0 +1,104 @@
+# copied from
+# https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/
+# srt/weight_sync/tensor_bucket.py
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import torch
+
+
+@dataclass
+class FlattenedTensorMetadata:
+ """Metadata for a tensor in a flattened bucket"""
+
+ name: str
+ shape: torch.Size
+ dtype: torch.dtype
+ start_idx: int
+ end_idx: int
+ numel: int
+
+
+class FlattenedTensorBucket:
+ """
+ A bucket that flattens multiple tensors into a single tensor for efficient processing
+ while preserving all metadata needed for reconstruction.
+ """
+
+ # This field is solely for users to check whether the class supports this feature
+ supports_multi_dtypes = True
+
+ def __init__(
+ self,
+ named_tensors: List[Tuple[str, torch.Tensor]] = None,
+ flattened_tensor: torch.Tensor = None,
+ metadata: List[FlattenedTensorMetadata] = None,
+ ):
+ """
+ Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.
+ Args:
+ named_tensors: List of (name, tensor) tuples (for creating new bucket)
+ flattened_tensor: Pre-flattened tensor (for reconstruction)
+ metadata: Pre-computed metadata (for reconstruction)
+ """
+ if named_tensors is not None:
+ # Create bucket from named tensors
+ self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
+ self.flattened_tensor: torch.Tensor = None
+
+ if not named_tensors:
+ raise ValueError("Cannot create empty tensor bucket")
+
+ # Collect metadata and flatten tensors
+ current_idx = 0
+ flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
+
+ for i, (name, tensor) in enumerate(named_tensors):
+ flattened = tensor.flatten().view(torch.uint8)
+ flattened_tensors[i] = flattened
+
+ # Store metadata
+
+ numel = flattened.numel()
+ metadata_obj = FlattenedTensorMetadata(
+ name=name,
+ shape=tensor.shape,
+ dtype=tensor.dtype,
+ start_idx=current_idx,
+ end_idx=current_idx + numel,
+ numel=numel,
+ )
+ self.metadata[i] = metadata_obj
+ current_idx += numel
+
+ # Concatenate all flattened tensors
+ self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
+ else:
+ # Initialize from pre-flattened data
+ if flattened_tensor is None or metadata is None:
+ raise ValueError("Must provide either named_tensors or both flattened_tensor and metadata")
+ self.flattened_tensor = flattened_tensor
+ self.metadata = metadata
+
+ def get_flattened_tensor(self) -> torch.Tensor:
+ """Get the flattened tensor containing all bucket tensors"""
+ return self.flattened_tensor
+
+ def get_metadata(self) -> List[FlattenedTensorMetadata]:
+ """Get metadata for all tensors in the bucket"""
+ return self.metadata
+
+ def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
+ """
+ Reconstruct original tensors from flattened tensor with optimized performance.
+ Uses memory-efficient operations to minimize allocations and copies.
+ """
+ # preallocate the result list
+ reconstructed = [None] * len(self.metadata)
+
+ for i, meta in enumerate(self.metadata):
+ tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].view(meta.dtype).reshape(meta.shape)
+
+ reconstructed[i] = (meta.name, tensor)
+
+ return reconstructed
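
A round-trip sketch for FlattenedTensorBucket: flatten a few named weights into one uint8 tensor, ship the tensor plus its metadata, and rebuild typed views on the other side.

    import torch
    from lightllm.utils.tensor_bucket import FlattenedTensorBucket

    named = [("w1", torch.randn(4, 8)), ("b1", torch.zeros(8, dtype=torch.float16))]
    bucket = FlattenedTensorBucket(named_tensors=named)

    flat, meta = bucket.get_flattened_tensor(), bucket.get_metadata()
    rebuilt = FlattenedTensorBucket(flattened_tensor=flat, metadata=meta).reconstruct_tensors()
    assert all(torch.equal(t, dict(named)[name]) for name, t in rebuilt)
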
diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py
new file mode 100644
index 0000000000..c1184ef30c
--- /dev/null
+++ b/lightllm/utils/torch_memory_saver_utils.py
@@ -0,0 +1,92 @@
+import torch
+from contextlib import contextmanager
+from enum import Enum
+from lightllm.utils.log_utils import init_logger
+
+try:
+ from torch_memory_saver import (
+ torch_memory_saver,
+ configure_subprocess,
+ )
+
+ HAS_TORCH_MEMORY_SAVER = True
+
+except ImportError:
+ HAS_TORCH_MEMORY_SAVER = False
+ pass
+
+logger = init_logger(__name__)
+
+
+class MemoryTag(Enum):
+ KV_CACHE = "kv_cache"
+ WEIGHT = "weights"
+ GRAPH = "graph"
+
+ def is_kv_cache(self):
+ return self == MemoryTag.KV_CACHE
+
+ def is_weight(self):
+ return self == MemoryTag.WEIGHT
+
+ def is_graph(self):
+ return self == MemoryTag.GRAPH
+
+ def __str__(self):
+ return self.value
+
+
+class TorchMemorySaverWrapper:
+ def __new__(cls, enable_torch_memory_saver: bool = False):
+ if enable_torch_memory_saver:
+ assert (
+ HAS_TORCH_MEMORY_SAVER
+ ), "torch_memory_saver is not installed, please install it via `pip install torch_memory_saver`."
+ return _TorchMemorySaver()
+ else:
+ return _TorchMemorySaverFake()
+
+
+class _TorchMemorySaver:
+ def configure_subprocess(self):
+ return configure_subprocess()
+
+ def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
+ return torch_memory_saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup)
+
+ def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs):
+ return torch_memory_saver.cuda_graph(cuda_graph=graph_obj, **kwargs, tag=MemoryTag.GRAPH.value)
+
+ def disable(self):
+ return torch_memory_saver.disable()
+
+ def pause(self, tag: MemoryTag):
+ return torch_memory_saver.pause(tag=tag.value)
+
+ def resume(self, tag: MemoryTag):
+ return torch_memory_saver.resume(tag=tag.value)
+
+
+class _TorchMemorySaverFake:
+ @contextmanager
+ def configure_subprocess(self):
+ yield
+
+ @contextmanager
+ def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
+ yield
+
+ def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs):
+ return torch.cuda.graph(graph_obj, **kwargs)
+
+ @contextmanager
+ def disable(self):
+ yield
+
+ def pause(self, tag: MemoryTag):
+ logger.warning("torch_memory_saver is not enabled, pause is not supported.")
+ return
+
+ def resume(self, tag: MemoryTag):
+ logger.warning("torch_memory_saver is not enabled, resume is not supported.")
+ return
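
A usage sketch for the wrapper (assumes a CUDA device and, when the flag is on, the torch_memory_saver package; with the flag off, the fake implementation turns every call into a no-op apart from the pause/resume warnings):

    import torch
    from lightllm.utils.torch_memory_saver_utils import TorchMemorySaverWrapper, MemoryTag

    saver = TorchMemorySaverWrapper(enable_torch_memory_saver=True)

    with saver.region(tag=MemoryTag.KV_CACHE):
        kv_buffer = torch.empty(1024, 1024, device="cuda")  # allocation is tagged as kv_cache

    saver.pause(MemoryTag.KV_CACHE)   # releases the tagged GPU memory
    saver.resume(MemoryTag.KV_CACHE)  # maps it back before the next forward
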
diff --git a/requirements.txt b/requirements.txt
index a3b9473f82..bbf9668221 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -95,3 +95,5 @@ partial_json_parser==0.2.1.1.post6
websockets==15.0.1
cupy-cuda12x==13.6.0
nixl==0.8.0
+torch_memory_saver==0.0.9
+portpicker==1.6.0
\ No newline at end of file
diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py
new file mode 100644
index 0000000000..85c4e44ef9
--- /dev/null
+++ b/test/test_api/test_r3.py
@@ -0,0 +1,92 @@
+import sys
+import argparse
+import requests
+import base64
+import numpy as np
+
+
+def test_routing_export(url: str = "http://localhost:8000"):
+ print(f"Testing routing export at {url}")
+ print("-" * 50)
+
+ try:
+ response = requests.post(
+ f"{url}/generate",
+ json={
+ "inputs": "What is the capital of France? What is the capital of France?",
+ "parameters": {
+ "max_new_tokens": 50,
+ # "return_routed_experts": True,
+ # "repetition_penalty": 1.0,
+ },
+ },
+ timeout=60,
+ )
+ except requests.exceptions.ConnectionError:
+ print(f"ERROR: Cannot connect to server at {url}")
+ print("Make sure the LightLLM server is running with --enable_return_routed_experts")
+ return False
+ except requests.exceptions.Timeout:
+ print("ERROR: Request timed out")
+ return False
+
+ print(f"Status: {response.status_code}")
+
+ if response.status_code != 200:
+ print(f"ERROR: Request failed with status {response.status_code}")
+ print(f"Response: {response.text}")
+ return False
+
+ res = response.json()
+ print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...")
+
+ if "routed_experts" not in res or not res["routed_experts"]:
+ print("\nWARNING: No routed_experts in response.")
+ print("This could mean:")
+ print(" - The model is not a MoE model")
+ print(" - The server was not started with --enable_return_routed_experts")
+ print(" - The routing capture manager was not initialized")
+ return False
+
+ routing_info = res["routed_experts"]
+ shape = routing_info["shape"]
+ dtype_str = routing_info["dtype"]
+ dtype = np.dtype(dtype_str)
+ data = base64.b64decode(routing_info["data"])
+ routing_array = np.frombuffer(data, dtype=dtype).reshape(shape)
+
+ print(f"\n{'=' * 50}")
+ print("ROUTING CAPTURE SUCCESS!")
+ print(f"{'=' * 50}")
+ print(f"Shape: {shape}")
+ print(f"Dtype: {dtype}")
+ print(f"Num tokens: {shape[0]}")
+ print(f"Num MoE layers: {shape[1]}")
+ print(f"Top-K: {shape[2]}")
+
+ # Compute payload size savings
+ int32_size = np.prod(shape) * 4
+ actual_size = len(data)
+ savings = (1 - actual_size / int32_size) * 100
+ print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)")
+
+ print(f"\nSample routing (first layer, first 5 tokens):")
+ num_tokens_to_show = min(5, shape[0])
+ for i in range(num_tokens_to_show):
+ print(f" Token {i}: experts {routing_array[i, 0, :].tolist()}")
+
+ if np.all(routing_array == 0):
+ print("\nWARNING: All routing data is zeros. Capture may not be working correctly.")
+ return False
+
+ print("\nTest PASSED!")
+ return True
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Test R3 routing export feature")
+ parser.add_argument("--url", default="http://localhost:8000", help="Server URL")
+ args = parser.parse_args()
+
+ success = test_routing_export(args.url)
+ sys.exit(0 if success else 1)
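
The test assumes the server returns routed_experts as base64-encoded raw bytes plus shape and dtype strings; a hypothetical producer-side encoding that matches what the test decodes (the real response format lives in the server code, which is not shown here):

    import base64
    import numpy as np

    routing_array = np.zeros((50, 48, 8), dtype=np.int8)  # (tokens, moe_layers, topk); illustrative sizes
    routed_experts = {
        "shape": list(routing_array.shape),
        "dtype": str(routing_array.dtype),  # "int8"
        "data": base64.b64encode(routing_array.tobytes()).decode("ascii"),
    }
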
diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py
new file mode 100644
index 0000000000..dcc010b372
--- /dev/null
+++ b/unit_tests/common/basemodel/test_routing_capture_manager.py
@@ -0,0 +1,219 @@
+import torch
+import numpy as np
+
+
+class TestRoutingCaptureManager:
+ def test_capture_and_extract_basic(self):
+ """Test the core pipeline: capture → flush_to_kv_buffer → extract_from_gpu."""
+ from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=4,
+ topk=8,
+ num_experts=64,
+ kv_cache_size=1024,
+ max_capture_tokens=64,
+ )
+
+ # Simulate a batch of 10 tokens at KV-cache positions [100..109]
+ mem_indexes = torch.arange(100, 110, device="cuda")
+
+ # Capture routing for each MoE layer (writes to capture buffer)
+ for layer_idx in range(4):
+ topk_ids = torch.randint(0, 64, (10, 8), device="cuda")
+ manager.capture(moe_layer_index=layer_idx, topk_ids=topk_ids, microbatch_index=0)
+
+ # Flush from capture buffer to KV-indexed gpu_kv_buffer
+ manager.flush_to_kv_buffer(mem_indexes, num_tokens=10, microbatch_index=0)
+
+ # Extract for those same KV-cache positions
+ result = manager.extract_from_gpu(mem_indexes)
+ assert result.shape == (4, 10, 8)
+ assert result.dtype == np.int8
+
+ def test_capture_writes_to_correct_kv_positions(self):
+ """Verify that captured data lands in the right KV-cache positions after flush."""
+ from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=2,
+ topk=4,
+ num_experts=32,
+ kv_cache_size=256,
+ max_capture_tokens=16,
+ )
+
+ # Use non-contiguous mem_indexes to simulate real KV-cache
+ mem_indexes = torch.tensor([10, 50, 200], device="cuda")
+
+ topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], device="cuda")
+ manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0)
+
+ topk_ids_layer1 = topk_ids + 20
+ manager.capture(moe_layer_index=1, topk_ids=topk_ids_layer1, microbatch_index=0)
+
+ # Flush to KV positions
+ manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0)
+
+ # Extract and verify
+ result = manager.extract_from_gpu(mem_indexes)
+ assert result.shape == (2, 3, 4)
+ np.testing.assert_array_equal(result[0], topk_ids.cpu().numpy())
+ np.testing.assert_array_equal(result[1], topk_ids_layer1.cpu().numpy())
+
+ def test_microbatch_isolation(self):
+ """Two microbatches writing to different KV positions don't interfere."""
+ from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=4,
+ num_experts=32,
+ kv_cache_size=256,
+ max_capture_tokens=16,
+ )
+
+ # Microbatch 0: positions [10, 11]
+ mem0 = torch.tensor([10, 11], device="cuda")
+ ids_0 = torch.ones((2, 4), dtype=torch.int64, device="cuda")
+ manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0)
+
+ # Microbatch 1: positions [20, 21]
+ mem1 = torch.tensor([20, 21], device="cuda")
+ ids_1 = torch.ones((2, 4), dtype=torch.int64, device="cuda") * 2
+ manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1)
+
+ # Flush each microbatch to different KV positions
+ manager.flush_to_kv_buffer(mem0, num_tokens=2, microbatch_index=0)
+ manager.flush_to_kv_buffer(mem1, num_tokens=2, microbatch_index=1)
+
+ # Extract microbatch 0
+ result0 = manager.extract_from_gpu(mem0)
+ assert result0.shape == (1, 2, 4)
+ assert result0[0, 0, 0] == 1
+
+ # Extract microbatch 1
+ result1 = manager.extract_from_gpu(mem1)
+ assert result1[0, 0, 0] == 2
+
+ def test_dtype_selection_int8(self):
+ """Models with ≤127 experts use int8."""
+ from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=2,
+ num_experts=64,
+ kv_cache_size=128,
+ max_capture_tokens=16,
+ )
+ assert manager.dtype == torch.int8
+ assert manager.np_dtype == np.int8
+ assert manager.dtype_id == 1
+
+ def test_dtype_selection_int16(self):
+ """Models with >127 experts use int16."""
+ from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=2,
+ num_experts=256,
+ kv_cache_size=128,
+ max_capture_tokens=16,
+ )
+ assert manager.dtype == torch.int16
+ assert manager.np_dtype == np.int16
+ assert manager.dtype_id == 2
+
+ def test_extract_preserves_values(self):
+ """Extracted values exactly match what was captured, no dtype truncation."""
+ from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=4,
+ num_experts=64,
+ kv_cache_size=64,
+ max_capture_tokens=16,
+ )
+
+ mem_indexes = torch.tensor([0, 1, 2], device="cuda")
+
+ topk_ids = torch.tensor([[10, 20, 30, 40], [50, 60, 63, 1], [0, 5, 127, 3]], device="cuda")
+ manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0)
+
+ # Flush then extract
+ manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0)
+ result = manager.extract_from_gpu(mem_indexes)
+ expected = topk_ids.cpu().numpy().astype(np.int8)
+ np.testing.assert_array_equal(result[0], expected)
+
+ def test_gpu_kv_buffer_shape(self):
+ """Buffer shape is (num_moe_layers, kv_cache_size, topk)."""
+ from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+ # 127 experts fits in int8 (max value 127)
+ manager = RoutingCaptureManager(
+ num_moe_layers=48,
+ topk=8,
+ num_experts=127,
+ kv_cache_size=2048,
+ max_capture_tokens=256,
+ )
+ assert manager.gpu_kv_buffer.shape == (48, 2048, 8)
+ assert manager.gpu_kv_buffer.dtype == torch.int8
+ assert manager.gpu_kv_buffer.device.type == "cuda"
+
+ # 128 experts requires int16
+ manager2 = RoutingCaptureManager(
+ num_moe_layers=48,
+ topk=8,
+ num_experts=128,
+ kv_cache_size=2048,
+ max_capture_tokens=256,
+ )
+ assert manager2.gpu_kv_buffer.dtype == torch.int16
+
+ def test_partial_token_capture(self):
+ """capture() only writes num_tokens rows to the buffer."""
+ from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=1,
+ topk=2,
+ num_experts=32,
+ kv_cache_size=128,
+ max_capture_tokens=16,
+ )
+
+ # Capture only 3 tokens, flush to 5 KV positions (first 3 get data)
+ mem_indexes = torch.tensor([10, 11, 12, 13, 14], device="cuda")
+
+ topk_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], device="cuda") # only 3 tokens
+ manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0)
+
+ # Flush only the 3 captured tokens
+ manager.flush_to_kv_buffer(mem_indexes[:3], num_tokens=3, microbatch_index=0)
+
+ # Positions 10-12 should have data, 13-14 should be zeros (from init)
+ result_written = manager.extract_from_gpu(mem_indexes[:3])
+ np.testing.assert_array_equal(result_written[0], topk_ids.cpu().numpy().astype(np.int8))
+
+ result_unwritten = manager.extract_from_gpu(mem_indexes[3:])
+ np.testing.assert_array_equal(result_unwritten[0], np.zeros((2, 2), dtype=np.int8))
+
+ def test_capture_buffer_shape(self):
+ """Capture buffer has correct shape (max_tokens, num_moe_layers, topk)."""
+ from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+ manager = RoutingCaptureManager(
+ num_moe_layers=4,
+ topk=8,
+ num_experts=64,
+ kv_cache_size=1024,
+ max_capture_tokens=256,
+ )
+ assert manager._capture_buffer[0].shape == (256, 4, 8)
+ assert manager._capture_buffer[1].shape == (256, 4, 8)
+ assert manager._capture_buffer[0].dtype == torch.int8
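
The tests above pin down the expected behaviour of RoutingCaptureManager: capture per MoE layer into a per-microbatch staging buffer, scatter into a KV-indexed GPU buffer on flush, and gather back by KV position on extract. A minimal reference consistent with those tests, offered only as an assumption-laden sketch and not the real routing_manager implementation:

    import numpy as np
    import torch

    class ToyRoutingCapture:
        def __init__(self, num_moe_layers, topk, num_experts, kv_cache_size, max_capture_tokens):
            self.dtype = torch.int8 if num_experts <= 127 else torch.int16
            # per-microbatch staging buffers: (max_tokens, layers, topk)
            self._capture_buffer = [
                torch.zeros(max_capture_tokens, num_moe_layers, topk, dtype=self.dtype, device="cuda")
                for _ in range(2)
            ]
            # KV-indexed buffer: (layers, kv_cache_size, topk)
            self.gpu_kv_buffer = torch.zeros(
                num_moe_layers, kv_cache_size, topk, dtype=self.dtype, device="cuda"
            )

        def capture(self, moe_layer_index, topk_ids, microbatch_index=0):
            n = topk_ids.shape[0]
            self._capture_buffer[microbatch_index][:n, moe_layer_index] = topk_ids.to(self.dtype)

        def flush_to_kv_buffer(self, mem_indexes, num_tokens, microbatch_index=0):
            staged = self._capture_buffer[microbatch_index][:num_tokens]  # (n, layers, topk)
            self.gpu_kv_buffer[:, mem_indexes] = staged.permute(1, 0, 2)  # scatter by KV position

        def extract_from_gpu(self, mem_indexes):
            return self.gpu_kv_buffer[:, mem_indexes].cpu().numpy()  # (layers, n, topk)
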
diff --git a/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py
new file mode 100644
index 0000000000..3b2f159f62
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py
@@ -0,0 +1,50 @@
+import pytest
+import torch
+
+from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import (
+ apply_invalid_token_ids,
+)
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+def test_apply_invalid_token_ids(dtype):
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA is required for Triton kernels.")
+
+ batch_size = 4
+ vocab_size = 32
+ logits = torch.randn((batch_size, vocab_size), device="cuda", dtype=dtype)
+ expected = logits.clone()
+
+ invalid_token_ids_per_batch = [
+ [1, 3, 5],
+ [],
+ [0, 2, 31],
+ [7],
+ ]
+
+ flat_ids = []
+ cu_invalid_token_num = [0]
+ invalid_token_num_start = 0
+ for ids in invalid_token_ids_per_batch:
+ flat_ids.extend(ids)
+ invalid_token_num_start += len(ids)
+ cu_invalid_token_num.append(invalid_token_num_start)
+
+ invalid_token_ids = torch.tensor(flat_ids, device="cuda", dtype=torch.int32)
+ cu_invalid_token_num = torch.tensor(cu_invalid_token_num, device="cuda", dtype=torch.int32)
+
+ for batch_idx, ids in enumerate(invalid_token_ids_per_batch):
+ if ids:
+ expected[batch_idx, ids] = float("-inf")
+
+ apply_invalid_token_ids(
+ Logits=logits,
+ invalid_token_ids=invalid_token_ids,
+ cu_invalid_token_num=cu_invalid_token_num,
+ )
+ assert torch.equal(logits, expected)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
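
For readers, the semantics the Triton kernel is tested against: for each batch row, the slice of the flat invalid_token_ids array delimited by cu_invalid_token_num[b]:cu_invalid_token_num[b+1] is masked to -inf in that row's logits. A pure-PyTorch reference (a sketch only, not the kernel itself):

    import torch

    def apply_invalid_token_ids_ref(logits, invalid_token_ids, cu_invalid_token_num):
        for b in range(logits.shape[0]):
            start, end = cu_invalid_token_num[b].item(), cu_invalid_token_num[b + 1].item()
            if end > start:
                logits[b, invalid_token_ids[start:end].long()] = float("-inf")
        return logits
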
diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
index 605433e9d8..dfeda0b6f7 100644
--- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
+++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
@@ -230,5 +230,32 @@ def test_case9():
assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64))
+def test_case10():
+ """
+ Test scenario: verify the flush_cache function
+ """
+ print("\nTest Case 10: Testing flush_cache function\n")
+ tree = RadixCache("unique_name", 100, 0)
+ tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
+ tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
+ tree_node, size, values = tree.match_prefix(
+ torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True
+ )
+ assert tree_node is not None
+ assert size == 3
+ tree.flush_cache()
+ tree_node, size, values = tree.match_prefix(
+ torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True
+ )
+ assert tree_node is None
+ assert size == 0
+ assert tree.get_tree_total_tokens_num() == 0
+ assert tree.get_refed_tokens_num() == 0
+ assert len(tree.root_node.children) == 0
+ assert tree.root_node.token_id_key.numel() == 0
+ assert tree.root_node.token_mem_index_value.numel() == 0
+ assert tree.root_node.ref_counter == 1
+
+
if __name__ == "__main__":
pytest.main()