From b636310df008c61dce2096835dfd668daa76999c Mon Sep 17 00:00:00 2001
From: shihaobai <42648726+shihaobai@users.noreply.github.com>
Date: Fri, 14 Nov 2025 13:38:31 +0800
Subject: [PATCH 01/71] add /flush_cache (#1108)

---
 lightllm/server/api_http.py                   | 26 +++++++++++
 lightllm/server/api_start.py                  |  8 ++--
 lightllm/server/core/objs/start_args_type.py  |  1 +
 lightllm/server/httpserver/manager.py         | 11 +++++
 lightllm/server/io_struct.py                  |  7 +++
 .../router/dynamic_prompt/radix_cache.py      | 24 ++++++++++
 lightllm/server/router/manager.py             | 34 ++++++++++++++
 lightllm/server/router/mananger_rpc.py        | 43 +++++++++++++++++++
 .../model_infer/mode_backend/base_backend.py  |  5 +++
 .../server/router/model_infer/model_rpc.py    | 19 ++++++++
 .../router/dynamic_prompt/test_radix_cache.py | 27 ++++++++++++
 11 files changed, 202 insertions(+), 3 deletions(-)
 create mode 100644 lightllm/server/io_struct.py
 create mode 100644 lightllm/server/router/mananger_rpc.py

diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 8bda50fb7..2ef01ea90 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -58,6 +58,7 @@
     CompletionRequest,
     CompletionResponse,
 )
+from .io_struct import AbortReq
 from .build_prompt import build_prompt, init_tokenizer
 
 logger = init_logger(__name__)
@@ -291,6 +292,30 @@ async def metrics() -> Response:
     return response
 
 
+@app.post("/abort_req")
+async def abort_req(request: AbortReq, raw_request: Request):
+    """Abort a request."""
+    try:
+        await g_objs.httpserver_manager.abort_req(request)
+        return Response(status_code=200)
+    except Exception as e:
+        return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
+
+
+@app.post("/flush_cache")
+@app.get("/flush_cache")
+async def flush_cache():
+    """Flush the radix cache."""
+    ret = await g_objs.httpserver_manager.flush_cache()
+    return Response(
+        content="Cache flushed successfully."
+        if ret
+        else "Cache flush failed. "
+        + "When there are running or waiting requests, the operation will not be performed.",
+        status_code=200 if ret else 500,
+    )
+
+
 @app.websocket("/pd_register")
 async def register_and_keep_alive(websocket: WebSocket):
     await websocket.accept()
@@ -357,6 +382,7 @@ async def startup_event():
     logger.info("server start up")
     loop = asyncio.get_event_loop()
     g_objs.set_args(get_env_start_args())
+    g_objs.httpserver_manager.connect_router_rpc()
     loop.create_task(g_objs.httpserver_manager.handle_loop())
     logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}")
     return
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index f73be30db..138b0a599 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -225,11 +225,12 @@ def normal_or_p_d_start(args):
 
     node_world_size = args.tp // args.nnodes
     can_use_ports = alloc_can_use_network_port(
-        num=8 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
+        num=9 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
     )
     logger.info(f"alloced ports: {can_use_ports}")
     (
         router_port,
+        router_rpc_port,
         detokenization_port,
         http_server_port,
         visual_port,
@@ -237,8 +238,8 @@ def normal_or_p_d_start(args):
         cache_port,
         metric_port,
         multi_level_kv_cache_port,
-    ) = can_use_ports[0:8]
-    can_use_ports = can_use_ports[8:]
+    ) = can_use_ports[0:9]
+    can_use_ports = can_use_ports[9:]
 
     visual_model_tp_ports = []
     for _ in range(args.visual_dp):
@@ -248,6 +249,7 @@ def normal_or_p_d_start(args):
 
     # 将申请好的端口放入args参数中
     args.router_port = router_port
+    args.router_rpc_port = router_rpc_port
     args.detokenization_port = detokenization_port
     args.http_server_port = http_server_port
     args.visual_port = visual_port
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 69d907fff..659aab1dc 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -113,6 +113,7 @@ class StartArgs:
     disk_cache_storage_size: float = field(default=10)
     # zmp ports
    router_port: int = field(default=None)
+    router_rpc_port: int = field(default=None)
     detokenization_port: int = field(default=None)
     http_server_port: int = field(default=None)
     visual_port: int = field(default=None)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 11919398e..7158b8923 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -255,6 +255,13 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False):
         assert False, "dead code path"
         return group_request_id
 
+    def connect_router_rpc(self):
+        from lightllm.server.router.mananger_rpc import connect_router_rpc
+
+        self.router_rpc_client = connect_router_rpc(self.args.router_rpc_port)
+        logger.info("HttpServerManager connected to Router RPC service successfully")
+        return
+
     async def generate(
         self,
         prompt: Union[str, List[int]],
@@ -763,6 +770,10 @@ async def handle_loop(self):
         self.recycle_event.set()
         return
 
+    async def flush_cache(self):
+        ret = await self.router_rpc_client.flush_cache()
+        return ret
+
 
 class ReqStatus:
     def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None:
diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py
new file mode 100644
index 000000000..68d32816f
--- /dev/null
+++ b/lightllm/server/io_struct.py
@@ -0,0 +1,7 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class AbortReq:
+    request_id: int = -1
+    abort_all: bool = False
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index 2bf0a4d5a..9c207ec30 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -424,6 +424,30 @@ def clear_tree_nodes(self):
         self.refed_tokens_num.arr[0] = 0
         return
 
+    def flush_cache(self):
+        nodes_to_clear = collections.deque(self.root_node.children.values())
+        self.root_node.children.clear()
+        while nodes_to_clear:
+            node = nodes_to_clear.popleft()
+            nodes_to_clear.extend(node.children.values())
+            node.parent = None
+            node.children.clear()
+
+        self.root_node.token_id_key[:] = 0
+        self.root_node.token_mem_index_value[:] = 0
+        self.root_node.ref_counter = 1  # 保持为1,确保不会被evict
+        self.root_node.time_id = time_gen.generate_time_id()
+        self.root_node.node_value_len = 0
+        self.root_node.node_prefix_total_len = 0
+
+        self.evict_tree_set.clear()
+        self.evict_tree_set.add(self.root_node)
+
+        self.tree_total_tokens_num.arr[0] = 0
+        self.refed_tokens_num.arr[0] = 0
+
+        return
+
     def dec_node_ref_counter(self, node: TreeNode):
         if node is None:
             return
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 3c8ca2399..b6132b12f 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -5,6 +5,7 @@
 import pickle
 import inspect
 import setproctitle
+import rpyc
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 import zmq
@@ -151,6 +152,9 @@ async def wait_to_model_ready(self):
             rpc_finished_event=self.rpc_finished_event,
         )
 
+        # 启动 rpyc 服务,供 HTTP Server 远程调用
+        self._start_router_rpc_service()
+
         kvargs = {
             "args": self.args,
             "rank_id": None,  # 由后续处理填充真实数据
@@ -231,6 +235,25 @@ async def wait_to_model_ready(self):
 
         return
 
+    def _start_router_rpc_service(self):
+        """launch a rpyc service for httpserver to call RouterManager"""
+        import threading
+        from rpyc.utils.server import ThreadedServer
+        import lightllm.utils.rpyc_fix_utils as _
+        from .mananger_rpc import RouterRpcService
+
+        service = RouterRpcService(self)
+        port = self.args.router_rpc_port
+
+        def start_server():
+            t = ThreadedServer(service, port=port, protocol_config={"allow_pickle": True})
+            t.start()
+
+        rpc_thread = threading.Thread(target=start_server, daemon=True)
+        rpc_thread.start()
+        logger.info(f"Router RPC service started successfully on port {port}")
+        return
+
     def _get_schedule_time_interval(self):
         # dp 模式,为了更好的配平,需要更长的调度间隔,以便于能收到更多的请求
         return self.schedule_time_interval
@@ -535,6 +558,17 @@ async def _recv_new_reqs_and_schedule(self):
             self._generate_new_batch()
         return
 
+    def flush_cache(self) -> bool:
+        if self.running_batch is not None:
+            return False
+        if self.req_queue.get_wait_req_num() > 0:
+            return False
+        # if radix cache client is not initialized, just return True
+        if self.radix_cache_client is None:
+            return True
+        # only flush cache when no running batch and no waiting requests
+        return self.model_rpc_client.flush_radix_cache()
+
     def clean_up(self):
         return
 
diff --git a/lightllm/server/router/mananger_rpc.py b/lightllm/server/router/mananger_rpc.py
new file mode 100644
index 000000000..60f9e0458
--- /dev/null
+++ b/lightllm/server/router/mananger_rpc.py
@@ -0,0 +1,43 @@
+import rpyc
+import asyncio
+import socket
+from .manager import RouterManager
+
+
+class RouterRpcService(rpyc.Service):
+    def __init__(self, router_manager: "RouterManager"):
+        super().__init__()
+        self.router_manager = router_manager
+        return
+
+    def exposed_flush_cache(self) -> bool:
+        return self.router_manager.flush_cache()
+
+
+class RouterRpcClient:
+    def __init__(self, router_rpc_conn):
+        self.router_rpc_conn = router_rpc_conn
+
+        def async_wrap(f):
+            f = rpyc.async_(f)
+
+            async def _func(*args, **kwargs):
+                ans = f(*args, **kwargs)
+                await asyncio.to_thread(ans.wait)
+                # raise if exception
+                return ans.value
+
+            return _func
+
+        self._flush_cache = async_wrap(self.router_rpc_conn.root.flush_cache)
+        return
+
+    async def flush_cache(self) -> bool:
+        ans = await self._flush_cache()
+        return ans
+
+
+def connect_router_rpc(port: int) -> RouterRpcClient:
+    router_rpc_conn = rpyc.connect("localhost", port, config={"allow_pickle": True})
+    router_rpc_conn._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+    return RouterRpcClient(router_rpc_conn)
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 95f0c9951..db708f3cf 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -288,6 +288,11 @@ def init_mtp_draft_model(self, main_kvargs: dict):
             self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
         return
 
+    def flush_radix_cache(self):
+        if self.radix_cache is not None:
+            self.radix_cache.flush_cache()
+        return
+
     def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
         """
         这个函数会把next token id和logprobs保存到pinned memory中
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 1bb625db0..b7797a762 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -181,6 +181,15 @@ def init_model(self, kvargs):
     def get_max_total_token_num(self):
         return self.backend.get_max_total_token_num()
 
+    def flush_radix_cache(self):
+        try:
+            if self.backend is not None:
+                self.backend.flush_radix_cache()
+            return True
+        except BaseException as e:
+            logger.exception(f"flush radix cache failed: {str(e)}")
+            return False
+
 
 class ModelRpcClient:
     def __init__(self, rpc_event, rpc_finished_event):
@@ -211,6 +220,16 @@ async def get_max_total_token_num(self):
         assert func_name == "get_max_total_token_num"
         return ret
 
+    def flush_radix_cache(self) -> bool:
+        self.rpc_shm_params.write_func_params("flush_radix_cache", ())
+        self.rpc_event.set()
+
+        self.rpc_finished_event.wait()
+        self.rpc_finished_event.clear()
+        func_name, ret = self.rpc_shm_results.read_func_result()
+        assert func_name == "flush_radix_cache"
+        return ret
+
 
 def _init_env(
     args,
diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
index 605433e9d..dfeda0b6f 100644
--- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
+++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
@@ -230,5 +230,32 @@ def test_case9():
     assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64))
 
 
+def test_case10():
+    """
+    测试场景:测试 flush_cache 函数
+    """
+    print("\nTest Case 10: Testing flush_cache function\n")
+    tree = RadixCache("unique_name", 100, 0)
+    tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
+    tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
+    tree_node, size, values = tree.match_prefix(
+        torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"),
update_refs=True + ) + assert tree_node is not None + assert size == 3 + tree.flush_cache() + tree_node, size, values = tree.match_prefix( + torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True + ) + assert tree_node is None + assert size == 0 + assert tree.get_tree_total_tokens_num() == 0 + assert tree.get_refed_tokens_num() == 0 + assert len(tree.root_node.children) == 0 + assert tree.root_node.token_id_key.numel() == 0 + assert tree.root_node.token_mem_index_value.numel() == 0 + assert tree.root_node.ref_counter == 1 + + if __name__ == "__main__": pytest.main() From 60c379ed7e9a25a8db867e0190ef7656d8d6dd7a Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Tue, 18 Nov 2025 14:51:43 +0800 Subject: [PATCH 02/71] Aborted reqs (#1113) --- lightllm/server/api_http.py | 6 +- lightllm/server/core/objs/io_objs/__init__.py | 2 +- .../server/core/objs/io_objs/group_req.py | 25 +--- lightllm/server/core/objs/req.py | 12 +- lightllm/server/detokenization/decode_req.py | 8 +- lightllm/server/detokenization/manager.py | 7 +- lightllm/server/httpserver/manager.py | 118 ++++++++++++------ lightllm/server/io_struct.py | 63 +++++++++- .../server/multi_level_kv_cache/manager.py | 8 +- lightllm/server/multimodal_params.py | 15 ++- lightllm/server/router/manager.py | 63 ++++++---- .../server/router/model_infer/infer_batch.py | 2 + .../server/router/req_queue/base_queue.py | 11 ++ .../req_queue/chunked_prefill/beam_impl.py | 5 +- .../router/req_queue/chunked_prefill/impl.py | 5 +- .../chunked_prefill/impl_for_nixl_pd.py | 3 +- .../chunked_prefill/impl_for_pd_decode.py | 3 +- .../server/router/req_queue/dp_base_queue.py | 6 + lightllm/server/visualserver/manager.py | 11 +- 19 files changed, 254 insertions(+), 119 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 2ef01ea90..2c8548873 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -292,11 +292,11 @@ async def metrics() -> Response: return response -@app.post("/abort_req") -async def abort_req(request: AbortReq, raw_request: Request): +@app.post("/abort_request") +async def abort_request(request: AbortReq, raw_request: Request): """Abort a request.""" try: - await g_objs.httpserver_manager.abort_req(request) + await g_objs.httpserver_manager.abort_request(request) return Response(status_code=200) except Exception as e: return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") diff --git a/lightllm/server/core/objs/io_objs/__init__.py b/lightllm/server/core/objs/io_objs/__init__.py index c9b806c47..10386b70e 100644 --- a/lightllm/server/core/objs/io_objs/__init__.py +++ b/lightllm/server/core/objs/io_objs/__init__.py @@ -1 +1 @@ -from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd +from .group_req import AbortedReqCmd, StopStrMatchedReqCmd diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index dfcbdd256..d644c0c31 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -1,33 +1,10 @@ from dataclasses import dataclass from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.core.objs.sampling_params import SamplingParams from typing import List from ..req import Req -@dataclass -class GroupReqIndexes: - group_req_id: int - multimodal_params: MultimodalParams - shm_req_indexes: List[int] - time_mark: float - - -@dataclass -class 
GroupReqObjs: - group_req_id: int - multimodal_params: MultimodalParams - shm_req_objs: List[Req] - time_mark: float - - def to_group_req_index(self): - return GroupReqIndexes( - group_req_id=self.group_req_id, - multimodal_params=self.multimodal_params, - shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], - time_mark=self.time_mark, - ) - - @dataclass class AbortedReqCmd: req_id: int diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 947f24644..e6f878b25 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -24,19 +24,20 @@ class FinishStatus(ctypes.Structure): NO_FINISH = 0 FINISHED_STOP = 1 FINISHED_LENGTH = 2 + FINISHED_ABORTED = 3 def __init__(self, init_state=NO_FINISH): self.status = init_state def set_status(self, new_status): - assert 0 <= new_status <= 2 + assert 0 <= new_status <= 3 self.status = new_status def get_status(self): return self.status def is_finished(self): - return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH + return self.FINISHED_STOP <= self.status <= self.FINISHED_ABORTED def is_stopped(self): return self.status == self.FINISHED_STOP @@ -49,6 +50,8 @@ def get_finish_reason(self): return "stop" elif self.status == self.FINISHED_LENGTH: return "length" + elif self.status == self.FINISHED_ABORTED: + return "abort" return None @@ -247,9 +250,8 @@ def can_release(self): ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - if self.is_aborted and can_released_mark and ref_count_ok: - return True - + # if self.is_aborted and can_released_mark and ref_count_ok: + # return True ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty(): diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py index 9aa3a8eff..c77379986 100644 --- a/lightllm/server/detokenization/decode_req.py +++ b/lightllm/server/detokenization/decode_req.py @@ -62,11 +62,7 @@ def stop_sequences_str_match(self) -> bool: return False def need_detoken(self): - if ( - (not self.req.is_aborted) - and (not self.req.stop_str_matched) - and len(self.output_ids) < self.req.candetoken_out_len - ): + if (not self.req.stop_str_matched) and len(self.output_ids) < self.req.candetoken_out_len: return True return False @@ -83,8 +79,6 @@ def get_decode_tokens(self): return prefix_tokens, read_tokens def can_set_release_mark(self): - if self.req.is_aborted: - return True if self.req.stop_str_matched: return True if ( diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8..ab5f706b9 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -6,7 +6,6 @@ import zmq import inspect from lightllm.server.core.objs import ShmReqManager, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.utils.graceful_utils import graceful_registry from typing import Union, Dict, List from .decode import decode_token @@ -17,6 +16,7 @@ import time from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import BaseReq logger = init_logger(__name__) @@ -46,7 +46,7 @@ def _init_get_token_id_to_token_str(self): self.token_id_to_token = {token_id: token for token, token_id in self.tokenizer.get_vocab().items()} return - def 
_add_new_group_req_index(self, recv_obj: GroupReqIndexes): + def _add_new_group_req_index(self, recv_obj: BaseReq): for req_index in recv_obj.shm_req_indexes: req = self.shm_req_manager.get_req_obj_by_index(req_index) req.link_prompt_ids_shm_array() @@ -74,8 +74,7 @@ def handle_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) + recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 7158b8923..5254e2097 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -25,12 +25,18 @@ from lightllm.server.core.objs import Req, FinishStatus, StartArgs from lightllm.server.core.objs import SamplingParams from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE -from lightllm.server.core.objs.io_objs import GroupReqObjs from lightllm.server.core.objs.shm_req_manager import ShmReqManager from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient +from lightllm.server.io_struct import ( + AbortReq, + BaseReq, + GenerateReq, + GenerateReqMeta, + GenerateReqIndex, +) from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name @@ -74,7 +80,7 @@ def __init__( self.multinode_req_manager = context.socket(zmq.PULL) self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}") logger.info( - f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}" + f"HttpServerManager listening for master node requests on *:{args.multinode_httpmanager_port}" ) self.enable_multimodal = args.enable_multimodal @@ -218,18 +224,32 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar async def loop_for_request(self): assert self.args.node_rank > 0 while True: - ( - prompt, - sampling_params, - multimodal_params, - ) = await self.multinode_req_manager.recv_pyobj() - results_generator = self.generate(prompt, sampling_params, multimodal_params, None) + req_obj = await self.multinode_req_manager.recv_pyobj() + if req_obj is None: + continue + if isinstance(req_obj, GenerateReqMeta): + self.process_generate_request(req_obj) + elif isinstance(req_obj, AbortReq): + self.process_abort_request(req_obj) + else: + assert False, f"Unknown request type: {type(req_obj)}" + return + + def process_generate_request(self, req_meta: GenerateReqMeta): + prompt = req_meta.prompt + sampling_params = req_meta.sampling_params + multimodal_params = req_meta.multimodal_params + results_generator = self.generate(prompt, sampling_params, multimodal_params, None) + + async def generate_wrapper(results_generator): + async for _, _, _, _ in results_generator: + pass - async def generate_wrapper(results_generator): - async for _, _, _, _ in results_generator: - pass + asyncio.create_task(generate_wrapper(results_generator)) + return - asyncio.create_task(generate_wrapper(results_generator)) + def process_abort_request(self, 
request: AbortReq): + asyncio.create_task(self.abort_request(request)) return def alloc_req_id(self, sampling_params, is_health_req: bool = False): @@ -279,10 +299,6 @@ async def generate( group_request_id = self.alloc_req_id(sampling_params, is_health_req) try: - original_multimodal_params = None - if self.is_multinode_tp_master: - original_multimodal_params = copy.deepcopy(multimodal_params) - if self.pd_mode.is_P_or_NORMAL(): await multimodal_params.verify_and_preload(request) @@ -346,12 +362,17 @@ async def generate( ) req_objs.append(req_obj) - req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) + req_status = ReqStatus( + group_request_id=group_request_id, + prompt=prompt, + sampling_params=sampling_params, + multimodal_params=multimodal_params, + req_objs=req_objs, + start_time=start_time, + ) self.req_id_to_out_inf[group_request_id] = req_status - await self.transfer_to_next_module_or_node( - prompt, sampling_params, original_multimodal_params, req_status.group_req_objs - ) + await self.transfer_to_next_module_or_node(req_status.group_req_objs) results_generator = self._wait_to_token_package( start_time, @@ -482,44 +503,49 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params: async def transfer_to_next_module_or_node( self, - prompt: str, - sampling_params: SamplingParams, - original_multimodal_params: MultimodalParams, - group_req_objs: Optional[GroupReqObjs] = None, + req_obj: Optional["BaseReq"] = None, ): # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. + req_to_next_node = req_obj.get_req_to_next_node() + self.transfer_to_next_node(req_to_next_node) + req_to_next_module = req_obj.get_req_to_next_module() + await self.transfer_to_next_module(req_to_next_module) + return + + def transfer_to_next_node( + self, + req_to_next_node: Optional["BaseReq"] = None, + ): if self.is_multinode_tp_master: for sender in self.multinode_req_manager: sender.send_pyobj( - (prompt, sampling_params, original_multimodal_params), + req_to_next_node, protocol=pickle.HIGHEST_PROTOCOL, ) - - await self.transfer_to_next_module(group_req_objs) return async def transfer_to_next_module( self, - group_req_objs: Optional[GroupReqObjs] = None, + req_to_next_module: Optional["GenerateReqIndex"] = None, ): if self.pd_mode.is_P_or_NORMAL(): if self.enable_multimodal: self.send_to_visual.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return if self.args.enable_cpu_cache: self.send_to_multi_level_kv_cache.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -527,7 +553,7 @@ async def transfer_to_next_module( if self.pd_mode.is_D(): # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了 self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -643,12 +669,24 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_id} not exist") return False - group_req_objs: GroupReqObjs = req_status.group_req_objs + group_req_objs: GenerateReq = req_status.group_req_objs for req in group_req_objs.shm_req_objs: req.is_aborted = True logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + async def abort_request(self, request: AbortReq): + request_id = request.request_id + 
abort_all = request.abort_all + if self.is_multinode_tp_master: + self.transfer_to_next_node(req_to_next_node=request) + if request_id is not None and not abort_all: + await self.abort(request_id) + if abort_all: + for group_req_id in list(self.req_id_to_out_inf.keys()): + await self.abort(group_req_id) + pass + async def recycle_resource_loop(self): pre_time_mark = time.time() @@ -776,11 +814,21 @@ async def flush_cache(self): class ReqStatus: - def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: + def __init__( + self, + group_request_id: int, + prompt: str, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + req_objs: List[Req], + start_time, + ) -> None: self.lock = asyncio.Lock() self.event = asyncio.Event() - self.group_req_objs = GroupReqObjs( + self.group_req_objs = GenerateReq( group_req_id=group_request_id, + prompt=prompt, + sampling_params=sampling_params, multimodal_params=multimodal_params, shm_req_objs=req_objs, time_mark=start_time, diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index 68d32816f..b5adff954 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -1,7 +1,66 @@ +from abc import ABC from dataclasses import dataclass +from lightllm.server.core.objs.req import Req +from lightllm.server.core.objs.sampling_params import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from typing import List @dataclass -class AbortReq: - request_id: int = -1 +class BaseReq(ABC): + def get_req_to_next_node(self): + return self + + def get_req_to_next_module(self): + return self + + +# for next node +@dataclass +class GenerateReqMeta(BaseReq): + prompt: str + sampling_params: SamplingParams + multimodal_params: MultimodalParams + + +# for next module +@dataclass +class GenerateReqIndex(BaseReq): + group_req_id: int + multimodal_params: MultimodalParams + shm_req_indexes: List[int] + time_mark: float + + +@dataclass +class GenerateReq(BaseReq): + group_req_id: int + prompt: str + sampling_params: SamplingParams + multimodal_params: MultimodalParams + shm_req_objs: List[Req] + time_mark: float + + def get_req_to_next_module(self): + # 已经完成跨节点转发,可以释放图片原始资源 + self.multimodal_params.free() + return GenerateReqIndex( + group_req_id=self.group_req_id, + multimodal_params=self.multimodal_params, + shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], + time_mark=self.time_mark, + ) + + def get_req_to_next_node(self): + return GenerateReqMeta( + prompt=self.prompt, + sampling_params=self.sampling_params, + multimodal_params=self.multimodal_params, + ) + + +@dataclass +class AbortReq(BaseReq): + # 外部调用传入,等同内部的 group_req_id + request_id: int = None abort_all: bool = False diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py index 8853e352e..e3bbe268b 100644 --- a/lightllm/server/multi_level_kv_cache/manager.py +++ b/lightllm/server/multi_level_kv_cache/manager.py @@ -10,7 +10,7 @@ import concurrent.futures from queue import Queue from lightllm.server.core.objs import ShmReqManager, Req, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.server.io_struct import GenerateReqIndex from lightllm.utils.graceful_utils import graceful_registry from .cpu_cache_client import CpuKvCacheClient from lightllm.utils.log_utils import init_logger @@ -51,7 +51,7 @@ def cpu_cache_hanle_loop(self): logger.exception(str(e)) return - def 
_handle_group_req_cpu_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float): + def _handle_group_req_cpu_cache_match(self, group_req_indexes: GenerateReqIndex, start_time: float): """ match cpu cache pages """ @@ -110,8 +110,8 @@ def recv_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) + recv_obj: GenerateReqIndex = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + assert isinstance(recv_obj, GenerateReqIndex) recv_objs.append(recv_obj) start_time = recv_obj.time_mark diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 066fe5cc2..9a1529a06 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -54,9 +54,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -112,9 +114,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -162,3 +166,10 @@ def to_origin_dict(self): ret = {} ret["images"] = [i.to_origin_dict() for i in self.images] return ret + + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + return diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index b6132b12f..ee3b7a957 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -18,7 +18,6 @@ from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue from lightllm.server.core.objs.io_objs import ( - GroupReqIndexes, AbortedReqCmd, StopStrMatchedReqCmd, ) @@ -31,6 +30,7 @@ from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager +from lightllm.server.io_struct import BaseReq, GenerateReqIndex from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name @@ -385,8 +385,13 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]: ans = [] if self.running_batch is None: return ans - for req in self.running_batch.reqs: - if req.is_aborted and req._router_aborted is False: + aborted_req_mask = torch.tensor( + [req.is_aborted for req in self.running_batch.reqs], dtype=torch.bool, device="cpu" + ) + if self.is_multinode_tp: + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + for req, is_aborted in zip(self.running_batch.reqs, aborted_req_mask.numpy()): + if is_aborted and req._router_aborted is False: req._router_aborted = True ans.append(req) return ans @@ -435,7 +440,7 @@ def get_used_tokens(self, dp_index): else: return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) - def _add_req(self, group_req_indexes: GroupReqIndexes): + def _add_req(self, group_req_indexes: BaseReq): req_group = [] for req_index in group_req_indexes.shm_req_indexes: req = self.shm_req_manager.get_req_obj_by_index(req_index) 
@@ -481,9 +486,22 @@ def _multinode_tp_generate_new_batch(self): dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) req_id_select_mark = [1 for _ in range(len(req_ids))] req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") + # TODO: 这里可以合成一个 allreudce,req_id_select_mark + aborted_req_mask dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + aborted_req_mask = torch.tensor( + [req.is_aborted for req in new_batch.reqs], dtype=torch.bool, device="cpu" + ) + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) back_req_list = [] - for req_id, select in zip(req_ids, req_id_select_mark.numpy()): + for req_id, select, is_aborted in zip( + req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy() + ): + # 释放多节点abort 请求,如果select == 0, is_aborted 一定为False + if is_aborted and select == 1: + req = new_batch.pop_req(req_id) + self.req_queue.free_aborted_req(req) + self.shm_req_manager.put_back_req_obj(req) + continue if select == 0: req = new_batch.pop_req(req_id) back_req_list.append(req) @@ -499,23 +517,28 @@ def _multinode_tp_generate_new_batch(self): else: req_ids = [None for _ in range(req_num)] dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) - all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list]) + # all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list]) + id_to_req_obj = {req.request_id: req for req in self.req_queue.waiting_req_list} req_id_select_mark = [] + aborted_req_mask = [] for req_id in req_ids: - req_id_select_mark.append(1 if req_id in all_req_id_set else 0) + req_id_select_mark.append(1 if req_id in id_to_req_obj else 0) + aborted_req_mask.append(id_to_req_obj[req_id].is_aborted if req_id in id_to_req_obj else False) req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) - select_req_ids = [] - for req_id, select in zip(req_ids, req_id_select_mark.numpy()): - if select == 1: - select_req_ids.append(req_id) - + aborted_req_mask = torch.tensor(aborted_req_mask, dtype=torch.bool, device="cpu") + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) select_reqs = [] - for req_id in select_req_ids: - for req in self.req_queue.waiting_req_list: - if req.request_id == req_id: - select_reqs.append(req) - + for req_id, select, is_aborted in zip( + req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy() + ): + if select == 1: + req = id_to_req_obj[req_id] + if is_aborted: + self.req_queue.free_aborted_req(req) + self.shm_req_manager.put_back_req_obj(req) + continue + select_reqs.append(req) for req in select_reqs: self.req_queue.waiting_req_list.remove(req) if select_reqs: @@ -538,11 +561,9 @@ async def _recv_new_reqs_and_schedule(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) - else: - assert False, f"Error Req Inf {recv_req}" # 当队列中存在较多的请求时,将一次接受的数量上调 self.recv_max_count = min(int(self.recv_max_count * 1.3), 256) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 
3fe3f5136..7bb01538d 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -477,6 +477,8 @@ def update_finish_status(self, eos_ids, output_len: int): self.finish_status.set_status(FinishStatus.FINISHED_STOP) elif output_len >= self.sampling_param.shm_param.max_new_tokens: self.finish_status.set_status(FinishStatus.FINISHED_LENGTH) + elif self.infer_aborted: + self.finish_status.set_status(FinishStatus.FINISHED_ABORTED) return def _stop_sequences_matched(self, output_len: int): diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index 36aefae6e..d7ef06828 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -34,6 +34,17 @@ def free_aborted_req_cpu_cache_pages(self, req: Req): req.cpu_cache_match_page_indexes.clear() self.router.cpu_cache_client.lock.release() + def free_aborted_req(self, req: Req): + # 为了让http server 能正常返回请求,还没有开始推理的请求,直接设置结束,返回空字符串 + input_len = req.input_len + req.link_prompt_ids_shm_array() + req.link_logprobs_shm_array() + req.candetoken_out_len = 1 + req.finish_token_index = input_len + req.shm_prompt_ids.arr[input_len] = self.args.eos_id[0] + req.shm_logprobs.arr[input_len] = 0 + req.finish_status.set_status(FinishStatus.FINISHED_ABORTED) + def extend(self, req_group: List[Req]): for req in req_group: req.sample_params.suggested_dp_index = self.dp_index diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py index ae7c90b33..ed2a5dbb1 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py @@ -89,7 +89,7 @@ def generate_new_batch(self, current_batch: Batch): aborted_count = 0 cur_group_reqs = [] for req in self.waiting_req_list: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: aborted_count += 1 abort_req_list.append(req) continue @@ -111,7 +111,7 @@ def generate_new_batch(self, current_batch: Batch): ok_insert, new_batch_first_router_need_tokens = self._can_add_new_group_reqs( cur_group_reqs, is_busy, new_batch_first_router_need_tokens ) - if ok_insert: + if ok_insert and False: can_run_list.extend(cur_group_reqs) new_batch = None @@ -120,6 +120,7 @@ def generate_new_batch(self, current_batch: Batch): for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 0d870b55d..9449798e9 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -79,8 +79,8 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: - # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. + if req.is_aborted and not self.router.is_multinode_tp: + # 由于管理的复杂性,只有没有被调度运行过的单节点请求可以因为abort直接在队列中忽略掉. 
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 abort_req_list.append(req) @@ -97,6 +97,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py index f2658159b..842b93648 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py @@ -70,7 +70,7 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 @@ -88,6 +88,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py index e0da13487..3dea3cf95 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py @@ -38,7 +38,7 @@ def generate_new_batch(self, current_batch: Batch): abort_req_list = [] aborted_count = 0 for req in self.waiting_req_list: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. 
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏 aborted_count += 1 @@ -53,6 +53,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index a73823b8b..e5f731df5 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -27,6 +27,12 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: self.reqs_waiting_for_dp_index: List[List[Req]] = [] return + def free_aborted_req(self, req: Req): + dp_index = req.sample_params.suggested_dp_index + assert dp_index >= 0 and dp_index < self.dp_size_in_node + self.inner_queues[dp_index].free_aborted_req(req) + return + def get_dp_queue(self, dp_index: int): assert dp_index < self.dp_size_in_node, "dp index out of range" return self.inner_queues[dp_index] diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index b7e1ac10c..fa3ac0c99 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -8,7 +8,6 @@ import inspect import setproctitle from typing import List -from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -18,6 +17,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import BaseReq, GenerateReqIndex from rpyc.utils.classic import obtain @@ -48,7 +48,7 @@ def __init__( self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cache_port = args.cache_port - self.waiting_reqs: List[GroupReqIndexes] = [] + self.waiting_reqs: List[BaseReq] = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp self.vit_dp = args.visual_dp @@ -171,11 +171,12 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + # 目前只有 GenerateReqIndex 会进入这个队列,判断是否需要推理图片 + if isinstance(recv_req, GenerateReqIndex): self.waiting_reqs.append(recv_req) else: - assert False, f"Error Req Inf {recv_req}" + self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL) self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 From 4095831bd165b8a605ddb0031c9ec2df754a7769 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:11:51 +0800 Subject: [PATCH 03/71] flush cache mulit node (#1116) --- lightllm/server/api_http.py | 1 - lightllm/server/detokenization/manager.py | 19 ++- lightllm/server/httpserver/manager.py | 148 
++++++++++++---------- lightllm/server/io_struct.py | 15 +++ lightllm/server/router/manager.py | 77 ++++++----- lightllm/server/router/mananger_rpc.py | 43 ------- 6 files changed, 161 insertions(+), 142 deletions(-) delete mode 100644 lightllm/server/router/mananger_rpc.py diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 2c8548873..0a8841f94 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -382,7 +382,6 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) - g_objs.httpserver_manager.connect_router_rpc() loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index ab5f706b9..7548342cd 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -16,7 +16,11 @@ import time from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.server.io_struct import BaseReq +from lightllm.server.io_struct import ( + BaseReq, + GenerateResp, + FlushCacheResp, +) logger = init_logger(__name__) @@ -31,9 +35,9 @@ def __init__( self.zmq_recv_socket = context.socket(zmq.PULL) self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}") - self.pub_to_httpserver = context.socket(zmq.PUB) - self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") - logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") + self.send_to_httpserver = context.socket(zmq.PUSH) + self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") + logger.info(f"send_to_httpserver sendhwm {self.send_to_httpserver.getsockopt(zmq.SNDHWM)}") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) self.all_special_ids = set(self.tokenizer.all_special_ids) self.req_id_to_out: Dict[int, DecodeReq] = {} @@ -75,6 +79,11 @@ def handle_loop(self): # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_obj, FlushCacheResp): + print("Detokenization receive flush cache request", flush=True) + self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) + print("Detokenization send flush cache request to httpserver", flush=True) + continue self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 @@ -145,7 +154,7 @@ def gen_token_out(self): # 通知 httpserver 进程 if exist_decode: - self.pub_to_httpserver.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_httpserver.send_pyobj(GenerateResp(), protocol=pickle.HIGHEST_PROTOCOL) self.remove_finished_reqs() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 5254e2097..0dab8fc8c 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -33,7 +33,10 @@ from lightllm.server.io_struct import ( AbortReq, BaseReq, + FlushCacheReq, + FlushCacheResp, GenerateReq, + GenerateResp, GenerateReqMeta, GenerateReqIndex, ) @@ -96,9 +99,8 @@ def __init__( self.shm_req_manager = ShmReqManager() # recv from detokenization - self.zmq_recv_socket = context.socket(zmq.SUB) + self.zmq_recv_socket = 
context.socket(zmq.PULL) self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") - self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) @@ -120,6 +122,9 @@ def __init__( # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + + # 交互式请求 event + self.flush_cache_event: Optional[asyncio.Event] = None return async def _alloc_resource(self, items, md5sums, token_nums, datas): @@ -275,13 +280,6 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): assert False, "dead code path" return group_request_id - def connect_router_rpc(self): - from lightllm.server.router.mananger_rpc import connect_router_rpc - - self.router_rpc_client = connect_router_rpc(self.args.router_rpc_port) - logger.info("HttpServerManager connected to Router RPC service successfully") - return - async def generate( self, prompt: Union[str, List[int]], @@ -743,64 +741,16 @@ async def handle_loop(self): while True: try: - await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) + recv_obj = await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) except asyncio.TimeoutError: - pass + recv_obj = None try: - for group_req_id_ in list(self.req_id_to_out_inf.keys()): - req_status = self.req_id_to_out_inf.get(group_req_id_, None) - if req_status is None: - continue + if recv_obj is None or isinstance(recv_obj, GenerateResp): + await self._handle_recv_generate_request(recv_obj) + elif isinstance(recv_obj, FlushCacheResp): + await self._handle_recv_flush_cache_request(recv_obj) - token_list = [] - for req in req_status.group_req_objs.shm_req_objs: - req_id = req.request_id - read_token_count = 1 - if req.out_tokens_queue.is_full(): - read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE - - for _ in range(read_token_count): - if not req.out_tokens_queue.is_empty(): - - text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() - req.cumlogprob += float(req.shm_logprobs.arr[src_index]) - metadata = { - "id": int(req.shm_prompt_ids.arr[src_index]), - "logprob": float(req.shm_logprobs.arr[src_index]), - "cumlogprob": float(req.cumlogprob) / count_output_tokens, - "special": special, - "count_output_tokens": count_output_tokens, - "prompt_cache_len": req.prompt_cache_len, - "cpu_prompt_cache_len": req.cpu_prompt_cache_len, - "mtp_accepted_token_num": req.mtp_accepted_token_num, - } - if self.args.return_all_prompt_logprobs: - metadata.update(req.get_all_prompt_metadata()) - if self.args.use_reward_model: - metadata["score"] = float(req.reward_score) - - req.out_tokens_queue.pop_no_ret() - - finished_token_index = ( - req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index - ) - - if finished_token_index != src_index: - token_list.append((req_id, text, metadata, FinishStatus())) - else: - if req.stop_str_matched: - finish_status = FinishStatus(FinishStatus.FINISHED_STOP) - else: - finish_status = FinishStatus(req.finish_status.status) - - token_list.append((req_id, text, metadata, finish_status)) - else: - break - - async with req_status.lock: - req_status.out_token_info_list.extend(token_list) - req_status.event.set() except BaseException as e: logger.exception(str(e)) raise e @@ -808,8 +758,78 @@ async def 
handle_loop(self): self.recycle_event.set() return + async def _handle_recv_generate_request(self, recv_obj: GenerateReqMeta): + for group_req_id_ in list(self.req_id_to_out_inf.keys()): + req_status = self.req_id_to_out_inf.get(group_req_id_, None) + if req_status is None: + continue + + token_list = [] + for req in req_status.group_req_objs.shm_req_objs: + req_id = req.request_id + read_token_count = 1 + if req.out_tokens_queue.is_full(): + read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE + + for _ in range(read_token_count): + if not req.out_tokens_queue.is_empty(): + + text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() + req.cumlogprob += float(req.shm_logprobs.arr[src_index]) + metadata = { + "id": int(req.shm_prompt_ids.arr[src_index]), + "logprob": float(req.shm_logprobs.arr[src_index]), + "cumlogprob": float(req.cumlogprob) / count_output_tokens, + "special": special, + "count_output_tokens": count_output_tokens, + "prompt_cache_len": req.prompt_cache_len, + "cpu_prompt_cache_len": req.cpu_prompt_cache_len, + "mtp_accepted_token_num": req.mtp_accepted_token_num, + } + if self.args.return_all_prompt_logprobs: + metadata.update(req.get_all_prompt_metadata()) + if self.args.use_reward_model: + metadata["score"] = float(req.reward_score) + + req.out_tokens_queue.pop_no_ret() + + finished_token_index = ( + req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index + ) + + if finished_token_index != src_index: + token_list.append((req_id, text, metadata, FinishStatus())) + else: + if req.stop_str_matched: + finish_status = FinishStatus(FinishStatus.FINISHED_STOP) + else: + finish_status = FinishStatus(req.finish_status.status) + + token_list.append((req_id, text, metadata, finish_status)) + else: + break + + async with req_status.lock: + req_status.out_token_info_list.extend(token_list) + req_status.event.set() + + async def _handle_recv_flush_cache_request(self, recv_obj: FlushCacheResp): + assert self.flush_cache_event is not None + self.flush_cache_event.success = recv_obj.success + self.flush_cache_event.set() + return + async def flush_cache(self): - ret = await self.router_rpc_client.flush_cache() + if self.flush_cache_event is None: + self.flush_cache_event = asyncio.Event() + await self.transfer_to_next_module(FlushCacheReq()) + try: + await asyncio.wait_for(self.flush_cache_event.wait(), timeout=30) + ret = self.flush_cache_event.success + except asyncio.TimeoutError: + # 超时直接返回失败 + ret = False + self.flush_cache_event.clear() return ret diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index b5adff954..2b4b3cef4 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -59,6 +59,21 @@ def get_req_to_next_node(self): ) +@dataclass +class GenerateResp(BaseReq): + pass + + +@dataclass +class FlushCacheReq(BaseReq): + pass + + +@dataclass +class FlushCacheResp(BaseReq): + success: bool + + @dataclass class AbortReq(BaseReq): # 外部调用传入,等同内部的 group_req_id diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index ee3b7a957..64b36a123 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -30,7 +30,12 @@ from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager -from lightllm.server.io_struct import BaseReq, GenerateReqIndex +from lightllm.server.io_struct import ( + BaseReq, + GenerateReqIndex, + 
FlushCacheReq, + FlushCacheResp, +) from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name @@ -152,9 +157,6 @@ async def wait_to_model_ready(self): rpc_finished_event=self.rpc_finished_event, ) - # 启动 rpyc 服务,供 HTTP Server 远程调用 - self._start_router_rpc_service() - kvargs = { "args": self.args, "rank_id": None, # 由后续处理填充真实数据 @@ -235,25 +237,6 @@ async def wait_to_model_ready(self): return - def _start_router_rpc_service(self): - """launch a rpyc service for httpserver to call RouterManager""" - import threading - from rpyc.utils.server import ThreadedServer - import lightllm.utils.rpyc_fix_utils as _ - from .mananger_rpc import RouterRpcService - - service = RouterRpcService(self) - port = self.args.router_rpc_port - - def start_server(): - t = ThreadedServer(service, port=port, protocol_config={"allow_pickle": True}) - t.start() - - rpc_thread = threading.Thread(target=start_server, daemon=True) - rpc_thread.start() - logger.info(f"Router RPC service started successfully on port {port}") - return - def _get_schedule_time_interval(self): # dp 模式,为了更好的配平,需要更长的调度间隔,以便于能收到更多的请求 return self.schedule_time_interval @@ -559,11 +542,15 @@ async def _recv_new_reqs_and_schedule(self): self.recv_max_count = 64 try: + # 多机tp需要广播给其他node的请求 + special_reqs = [] # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) + elif isinstance(recv_req, FlushCacheReq): + special_reqs.append(recv_req) # 当队列中存在较多的请求时,将一次接受的数量上调 self.recv_max_count = min(int(self.recv_max_count * 1.3), 256) @@ -572,6 +559,8 @@ async def _recv_new_reqs_and_schedule(self): # 当队列已经开始清空的时候,将一次接受的数量下调 self.recv_max_count = 64 + self._process_special_reqs(special_reqs) + if self.is_multinode_tp: self._multinode_tp_generate_new_batch() else: @@ -579,16 +568,46 @@ async def _recv_new_reqs_and_schedule(self): self._generate_new_batch() return + def _process_special_reqs(self, special_reqs: List[BaseReq]): + if self.is_multinode_tp: + special_reqs = self.broadcast_reqs_to_other_nodes(special_reqs) + for req in special_reqs: + if isinstance(req, FlushCacheReq): + self.flush_cache() + + def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): + req_num = len(reqs) + if self.node_rank == 0: + req_nums = [len(reqs)] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + else: + req_nums = [None] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + reqs = [None for _ in range(req_num)] + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + return reqs + def flush_cache(self) -> bool: - if self.running_batch is not None: - return False - if self.req_queue.get_wait_req_num() > 0: - return False # if radix cache client is not initialized, just return True if self.radix_cache_client is None: - return True + success = True # only flush cache when no running batch and no waiting requests - return self.model_rpc_client.flush_radix_cache() + elif self.running_batch is not None or self.req_queue.get_wait_req_num() > 0: + success = False + else: + success = self.model_rpc_client.flush_radix_cache() + + if self.is_multinode_tp: + # 等待其他节点的flush 
结果 + dist.barrier(group=self.mulitnode_group) + if self.is_multinode_tp_master: + self.send_to_detokenization.send_pyobj(FlushCacheResp(success=success), protocol=pickle.HIGHEST_PROTOCOL) + return success def clean_up(self): return diff --git a/lightllm/server/router/mananger_rpc.py b/lightllm/server/router/mananger_rpc.py deleted file mode 100644 index 60f9e0458..000000000 --- a/lightllm/server/router/mananger_rpc.py +++ /dev/null @@ -1,43 +0,0 @@ -import rpyc -import asyncio -import socket -from .manager import RouterManager - - -class RouterRpcService(rpyc.Service): - def __init__(self, router_manager: "RouterManager"): - super().__init__() - self.router_manager = router_manager - return - - def exposed_flush_cache(self) -> bool: - return self.router_manager.flush_cache() - - -class RouterRpcClient: - def __init__(self, router_rpc_conn): - self.router_rpc_conn = router_rpc_conn - - def async_wrap(f): - f = rpyc.async_(f) - - async def _func(*args, **kwargs): - ans = f(*args, **kwargs) - await asyncio.to_thread(ans.wait) - # raise if exception - return ans.value - - return _func - - self._flush_cache = async_wrap(self.router_rpc_conn.root.flush_cache) - return - - async def flush_cache(self) -> bool: - ans = await self._flush_cache() - return ans - - -def connect_router_rpc(port: int) -> RouterRpcClient: - router_rpc_conn = rpyc.connect("localhost", port, config={"allow_pickle": True}) - router_rpc_conn._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - return RouterRpcClient(router_rpc_conn) From ca9325fd0cd590adec3032c4d2024e5337327e18 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:56:32 +0800 Subject: [PATCH 04/71] [bugfix]: flush cache in single node (#1118) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- lightllm/server/router/manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 64b36a123..35fd861d8 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -592,7 +592,7 @@ def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) return reqs - def flush_cache(self) -> bool: + def flush_cache(self) -> None: # if radix cache client is not initialized, just return True if self.radix_cache_client is None: success = True @@ -605,9 +605,9 @@ def flush_cache(self) -> bool: if self.is_multinode_tp: # 等待其他节点的flush 结果 dist.barrier(group=self.mulitnode_group) - if self.is_multinode_tp_master: + if self.node_rank == 0: self.send_to_detokenization.send_pyobj(FlushCacheResp(success=success), protocol=pickle.HIGHEST_PROTOCOL) - return success + return def clean_up(self): return From 99489258234bb2256b5908285aa5230c70af00bf Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Wed, 19 Nov 2025 22:03:35 +0800 Subject: [PATCH 05/71] add pause and continue (#1120) --- lightllm/server/api_http.py | 12 ++++++++++++ lightllm/server/httpserver/manager.py | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 0a8841f94..07fcc4139 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -316,6 +316,18 @@ async def flush_cache(): ) +@app.post("/pause_generation") +async def pause_generation(): + await 
g_objs.httpserver_manager.pause_generation() + return Response(content="Generation paused successfully.", status_code=200) + + +@app.post("/continue_generation") +async def continue_generation(): + await g_objs.httpserver_manager.continue_generation() + return Response(content="Generation continued successfully.", status_code=200) + + @app.websocket("/pd_register") async def register_and_keep_alive(websocket: WebSocket): await websocket.accept() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 0dab8fc8c..765b44eea 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -123,6 +123,9 @@ def __init__( self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + self.is_pause = False + self.is_pause_cond = asyncio.Condition() + # 交互式请求 event self.flush_cache_event: Optional[asyncio.Event] = None return @@ -302,6 +305,10 @@ async def generate( # 记录请求到达的相关信息 await self._log_req_header(request_headers, group_request_id) + + async with self.is_pause_cond: + await self.is_pause_cond.wait_for(lambda: not self.is_pause) + # encode prompt_ids = await self._encode(prompt, multimodal_params, sampling_params) @@ -832,6 +839,23 @@ async def flush_cache(self): self.flush_cache_event.clear() return ret + async def pause_generation(self): + # 因为请求是从master node转发到slave node的 + # 所以只要master暂停了,slave自然暂停。 + async with self.is_pause_cond: + self.is_pause = True + while True: + await self.abort_request(AbortReq(request_id=None, abort_all=True)) + running_req_num = len(list(self.req_id_to_out_inf.keys())) + if running_req_num == 0: + break + await asyncio.sleep(1.0) + + async def continue_generation(self): + async with self.is_pause_cond: + self.is_pause = False + self.is_pause_cond.notify_all() + class ReqStatus: def __init__( From 4b32287d662c8501d9c7c2a1f193c265d01fc6be Mon Sep 17 00:00:00 2001 From: sufubao <47234901+sufubao@users.noreply.github.com> Date: Fri, 21 Nov 2025 17:39:17 +0800 Subject: [PATCH 06/71] add launch_server and StartArgs (#1119) --- lightllm/server/api_http.py | 15 ++++- lightllm/server/api_server.py | 25 ++++++-- lightllm/server/api_start.py | 33 ++++++++--- lightllm/server/core/objs/start_args_type.py | 62 ++++++++++++++------ lightllm/utils/device_utils.py | 2 +- 5 files changed, 106 insertions(+), 31 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 07fcc4139..b96cf9306 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -33,7 +33,7 @@ import uuid from PIL import Image import multiprocessing as mp -from typing import AsyncGenerator, Union +from typing import Any, AsyncGenerator, Union from typing import Callable from lightllm.server import TokenLoad from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect @@ -131,6 +131,19 @@ def get_model_name(): return {"model_name": g_objs.args.model_name} +@app.get("/get_server_info") +@app.post("/get_server_info") +def get_server_info(): + # 将 StartArgs 转换为字典格式 + from dataclasses import asdict + server_info: dict[str, Any] = asdict(g_objs.args) + return {**server_info} + +@app.get("/get_weight_version") +@app.post("/get_weight_version") +def get_weight_version(): + return {"weight_version": g_objs.args.weight_version} + @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", 
summary="Check server health") diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808..dd531f58d 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -1,11 +1,21 @@ import torch from .api_cli import make_argument_parser +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.utils.log_utils import init_logger -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess - parser = make_argument_parser() - args = parser.parse_args() +logger = init_logger(__name__) + +def launch_server(args: StartArgs): from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + + try: + # this code will not be ok for settings to fork to subprocess + torch.multiprocessing.set_start_method("spawn") + except RuntimeError as e: + logger.warning(f"Failed to set start method: {e}") + except Exception as e: + logger.error(f"Failed to set start method: {e}") + raise e if args.run_mode == "pd_master": pd_master_start(args) @@ -13,3 +23,10 @@ config_server_start(args) else: normal_or_p_d_start(args) + + +if __name__ == "__main__": + parser = make_argument_parser() + args = parser.parse_args() + + launch_server(StartArgs(**vars(args))) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 138b0a599..6a02dda17 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -16,6 +16,7 @@ from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip from lightllm.utils.shm_size_check import check_recommended_shm_size +from lightllm.server.core.objs.start_args_type import StartArgs logger = init_logger(__name__) @@ -51,20 +52,38 @@ def signal_handler(sig, frame): process_manager.terminate_all_processes() logger.info("All processes have been terminated gracefully.") sys.exit(0) + elif sig == signal.SIGHUP: + logger.info("Received SIGHUP (terminal closed), shutting down gracefully...") + if http_server_process and http_server_process.poll() is None: + http_server_process.send_signal(signal.SIGTERM) + + start_time = time.time() + while (time.time() - start_time) < 60: + if not is_process_active(http_server_process.pid): + logger.info("httpserver exit") + break + time.sleep(1) + + if time.time() - start_time < 60: + logger.info("HTTP server has exited gracefully") + else: + logger.warning("HTTP server did not exit in time, killing it...") + kill_recursive(http_server_process) + + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully due to terminal closure.") + sys.exit(0) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGHUP, signal_handler) logger.info(f"start process pid {os.getpid()}") logger.info(f"http server pid {http_server_process.pid}") return -def normal_or_p_d_start(args): - from lightllm.server.core.objs.start_args_type import StartArgs - - args: StartArgs = args - +def normal_or_p_d_start(args: StartArgs): set_unique_server_name(args) if not args.disable_shm_warning: @@ -370,7 +389,7 @@ def normal_or_p_d_start(args): return -def pd_master_start(args): +def pd_master_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "pd_master": return @@ -433,7 +452,7 @@ def pd_master_start(args): http_server_process.wait() -def config_server_start(args): +def config_server_start(args: StartArgs): 
set_unique_server_name(args) if args.run_mode != "config_server": return diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 659aab1dc..40f68a743 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -1,48 +1,52 @@ from dataclasses import dataclass, field from typing import List, Optional, Tuple -# 只是为了更好的编程提示 +# 服务启动参数 @dataclass class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + metadata={"choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"]}, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) + httpserver_workers: int = field(default=1) zmq_mode: str = field( default="ipc:///tmp/", metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"}, ) - pd_master_ip: str = field(default="127.0.0.1") + pd_master_ip: str = field(default="0.0.0.0") pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) pd_decode_rpyc_port: int = field(default=42000) - select_p_d_node_strategy: str = field(default=None) + select_p_d_node_strategy: str = field( + default="round_robin", + metadata={"choices": ["random", "round_robin", "adaptive_load"]} + ) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) - tokenizer_mode: str = field(default="slow") + tokenizer_mode: str = field(default="fast") load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.9) batch_max_tokens: Optional[int] = field(default=None) - eos_id: List[int] = field(default_factory=list) + eos_id: Optional[List[int]] = field(default=None) tool_call_parser: Optional[str] = field( - default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]} + default=None, metadata={"choices": ["qwen25", "llama3", "mistral", "deepseekv3", "qwen"]} ) running_max_req_size: int = field(default=1000) tp: int = field(default=1) dp: int = field(default=1) nnodes: int = field(default=1) node_rank: int = field(default=0) - max_req_total_len: int = field(default=2048 + 1024) + max_req_total_len: int = field(default=16384) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=28765) use_config_server_to_init_nccl: bool = field(default=False) - mode: List[str] = field(default_factory=list) + mode: List[str] = field(default_factory=lambda: []) trust_remote_code: bool = field(default=False) disable_log_stats: bool = field(default=False) log_stats_interval: int = field(default=10) @@ -51,11 +55,14 @@ class StartArgs: router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) - chunked_prefill_size: int = field(default=8192) + chunked_prefill_size: int = field(default=4096) disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]}) + output_constraint_mode: str = field( + default="none", + metadata={"choices": ["outlines", "xgrammar", "none"]} + ) first_token_constraint_mode: bool = field(default=False) 
enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) @@ -74,10 +81,10 @@ class StartArgs: health_monitor: bool = field(default=False) metric_gateway: Optional[str] = field(default=None) job_name: str = field(default="lightllm") - grouping_key: List[str] = field(default_factory=list) + grouping_key: List[str] = field(default_factory=lambda: []) push_interval: int = field(default=10) visual_infer_batch_size: int = field(default=1) - visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) + visual_gpu_ids: Optional[List[int]] = field(default=None) visual_tp: int = field(default=1) visual_dp: int = field(default=1) visual_nccl_ports: List[int] = field(default_factory=lambda: [29500]) @@ -86,10 +93,10 @@ class StartArgs: graph_max_batch_size: int = field(default=256) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) - graph_max_len_in_batch: int = field(default=8192) - quant_type: Optional[str] = field(default=None) + graph_max_len_in_batch: int = field(default=0) + quant_type: Optional[str] = field(default="none") quant_cfg: Optional[str] = field(default=None) - vit_quant_type: Optional[str] = field(default=None) + vit_quant_type: Optional[str] = field(default="none") vit_quant_cfg: Optional[str] = field(default=None) enable_flashinfer_prefill: bool = field(default=False) enable_flashinfer_decode: bool = field(default=False) @@ -99,7 +106,10 @@ class StartArgs: ) ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) - mtp_mode: Optional[str] = field(default=None) + mtp_mode: Optional[str] = field( + default=None, + metadata={"choices": ["deepseekv3_vanilla", "deepseekv3_eagle", None]} + ) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) kv_quant_calibration_config_path: Optional[str] = field(default=None) @@ -108,7 +118,7 @@ class StartArgs: pd_node_id: int = field(default=-1) enable_cpu_cache: bool = field(default=False) cpu_cache_storage_size: float = field(default=2) - cpu_cache_token_page_size: int = field(default=64) + cpu_cache_token_page_size: int = field(default=256) enable_disk_cache: bool = field(default=False) disk_cache_storage_size: float = field(default=10) # zmp ports @@ -128,3 +138,19 @@ class StartArgs: # kernel setting enable_fa3: bool = field(default=False) + + httpserver_workers: int = field(default=1) + disable_shm_warning: bool = field(default=False) + dp_balancer: str = field( + default="bs_balancer", + metadata={"choices": ["round_robin", "bs_balancer"]} + ) + enable_custom_allgather: bool = field(default=False) + enable_fused_shared_experts: bool = field(default=False) + enable_mps: bool = field(default=False) + multinode_router_gloo_port: int = field(default=20001) + schedule_time_interval: float = field(default=0.03) + use_dynamic_prompt_cache: bool = field(default=False) + disable_custom_allreduce: bool = field(default=False) + + weight_version: str = "default" diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index b4d1ba629..d2b6d06a8 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -85,7 +85,7 @@ def get_current_device_name(): gpu_name = torch.cuda.get_device_name(device).replace(" ", "_") return gpu_name else: - return None + raise RuntimeError("No GPU available") @lru_cache(maxsize=None) From 27abcf536894eb5e6f91871a5120debd07adabde Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Mon, 1 
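# Illustrative sketch (not part of the patches): starting the server
# programmatically through the launch_server()/StartArgs entry point introduced
# above, mirroring what the new __main__ block does. The model path is a
# placeholder and every field left out keeps its StartArgs default.
from lightllm.server.api_server import launch_server
from lightllm.server.core.objs.start_args_type import StartArgs

if __name__ == "__main__":
    args = StartArgs(
        model_dir="/path/to/your/model",  # placeholder
        host="0.0.0.0",
        port=8000,
        tp=1,
    )
    launch_server(args)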
Dec 2025 16:39:35 +0800 Subject: [PATCH 07/71] Update weight (#1127) Co-authored-by: Weichao Luo Co-authored-by: shihaobai <1798930569@qq.com> --- lightllm/common/basemodel/basemodel.py | 7 + lightllm/server/api_http.py | 42 +++- lightllm/server/detokenization/manager.py | 5 + lightllm/server/httpserver/manager.py | 73 +++++++ lightllm/server/io_struct.py | 85 +++++++- lightllm/server/router/manager.py | 21 +- .../model_infer/mode_backend/base_backend.py | 181 +++++++++++++++++- .../server/router/model_infer/model_rpc.py | 20 ++ lightllm/utils/dist_utils.py | 74 ++++++- lightllm/utils/patch_torch.py | 65 +++++++ lightllm/utils/serializer.py | 132 +++++++++++++ lightllm/utils/tensor_bucket.py | 108 +++++++++++ 12 files changed, 805 insertions(+), 8 deletions(-) create mode 100644 lightllm/utils/patch_torch.py create mode 100644 lightllm/utils/serializer.py create mode 100644 lightllm/utils/tensor_bucket.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 77ca299b2..1221f1939 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -178,6 +178,13 @@ def _init_weights(self): [weight.verify_load() for weight in self.trans_layers_weight] return + def load_weights(self, weight_dict: dict): + load_hf_weights(self.data_type, + self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=weight_dict) + def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 self.mem_manager = MemoryManager( diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index b96cf9306..28ae93dfb 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -58,7 +58,14 @@ CompletionRequest, CompletionResponse, ) -from .io_struct import AbortReq +from .io_struct import ( + AbortReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, + GeneralModelToHttpRpcRsp +) from .build_prompt import build_prompt, init_tokenizer logger = init_logger(__name__) @@ -315,6 +322,39 @@ async def abort_request(request: AbortReq, raw_request: Request): return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") +async def handle_request_common(request_obj, handler): + try: + ret: GeneralModelToHttpRpcRsp = await handler(request_obj) + if ret.success: + return JSONResponse({"success": ret.success, "message": ret.msg}, status_code=200) + else: + return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) + except Exception as e: + return create_error_response( + HTTPStatus.EXPECTATION_FAILED, + f"error: {str(e)}" + ) + +@app.post("/init_weights_update_group") +async def init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request): + """Init weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group) + +@app.post("/destroy_weights_update_group") +async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request): + """Destroy weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group) + +@app.post("/update_weights_from_distributed") +async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, 
g_objs.httpserver_manager.update_weights_from_distributed) + +@app.post("/update_weights_from_tensor") +async def update_weights_from_distributed(request: UpdateWeightsFromTensorReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor) + @app.post("/flush_cache") @app.get("/flush_cache") async def flush_cache(): diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 7548342cd..b7ba96025 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -20,6 +20,7 @@ BaseReq, GenerateResp, FlushCacheResp, + GeneralModelToHttpRpcRsp ) logger = init_logger(__name__) @@ -84,6 +85,10 @@ def handle_loop(self): self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) print("Detokenization send flush cache request to httpserver", flush=True) continue + elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): + self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) + print(f"Detokenization send {recv_obj.func_name} request to httpserver") + continue self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 765b44eea..44721cfe3 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import hashlib import datetime import pickle +import inspect from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -39,6 +40,17 @@ GenerateResp, GenerateReqMeta, GenerateReqIndex, + InitWeightsUpdateGroupReq, + InitWeightsUpdateGroupRsp, + DestroyWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupRsp, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromDistributedRsp, + UpdateWeightsFromTensorReq, + UpdateWeightsFromTensorRsp, + GeneralHttpToModelRpcReq, + GeneralModelToHttpRpcRsp + ) from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size @@ -128,6 +140,7 @@ def __init__( # 交互式请求 event self.flush_cache_event: Optional[asyncio.Event] = None + self.async_events_per_func: Dict[str, asyncio.Event] = {} return async def _alloc_resource(self, items, md5sums, token_nums, datas): @@ -757,6 +770,8 @@ async def handle_loop(self): await self._handle_recv_generate_request(recv_obj) elif isinstance(recv_obj, FlushCacheResp): await self._handle_recv_flush_cache_request(recv_obj) + elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): + await self._handle_recv_general_model_to_http_request(recv_obj) except BaseException as e: logger.exception(str(e)) @@ -826,6 +841,13 @@ async def _handle_recv_flush_cache_request(self, recv_obj: FlushCacheResp): self.flush_cache_event.set() return + async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralModelToHttpRpcRsp): + assert recv_obj.func_name is not None + event = await self.get_event_for_func(recv_obj.func_name) + event.result = recv_obj + event.set() + return + async def flush_cache(self): if self.flush_cache_event is None: self.flush_cache_event = asyncio.Event() @@ -856,6 +878,57 @@ async def continue_generation(self): self.is_pause = False self.is_pause_cond.notify_all() + async def get_event_for_func(self, func_name: str) -> asyncio.Event: + if func_name not in self.async_events_per_func: + self.async_events_per_func[func_name] = asyncio.Event() + return 
self.async_events_per_func[func_name] + + async def http_to_model_special_request(self, request: GeneralHttpToModelRpcReq, timeout: int=300) -> GeneralModelToHttpRpcRsp: + event = await self.get_event_for_func(request.func_name) + await self.transfer_to_next_module(request) + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + ret = event.result + + except asyncio.TimeoutError: + ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response timeout", func_name=request.func_name) + except Exception as e: + ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name) + return ret + + + async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + return await self.http_to_model_special_request(GeneralHttpToModelRpcReq( + func_name="init_weights_update_group", func_args=request)) + + + async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + return await self.http_to_model_special_request(GeneralHttpToModelRpcReq( + func_name="destroy_weights_update_group", func_args=request)) + + + async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + + if request.flush_cache: + await self.flush_cache() + + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request)) + + async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> Tuple[bool, str]: + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + + if request.flush_cache: + await self.flush_cache() + + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="update_weights_from_tensor", func_args=request) + ) + class ReqStatus: def __init__( diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index 2b4b3cef4..7947c7a58 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -3,7 +3,7 @@ from lightllm.server.core.objs.req import Req from lightllm.server.core.objs.sampling_params import SamplingParams from lightllm.server.multimodal_params import MultimodalParams -from typing import List +from typing import List, Optional, Any, Union @dataclass @@ -14,6 +14,10 @@ def get_req_to_next_node(self): def get_req_to_next_module(self): return self +@dataclass +class BaseRsp(ABC): + success: bool + msg: Optional[str] # for next node @dataclass @@ -79,3 +83,82 @@ class AbortReq(BaseReq): # 外部调用传入,等同内部的 group_req_id request_id: int = None abort_all: bool = False + + +@dataclass +class GeneralHttpToModelRpcReq(BaseReq): + func_name: str + func_args: Optional[Any] = None + +@dataclass +class GeneralModelToHttpRpcRsp(BaseRsp): + func_name: str + func_rsp: Optional[Any] = None + +@dataclass +class InitWeightsUpdateGroupReq(BaseReq): + # The master address + master_address: str + # The master port + master_port: int + # The rank offset + rank_offset: int + # The world size + world_size: int + # The group name + group_name: str = "weight_update_group" + # The backend + backend: str = "nccl" + +@dataclass +class InitWeightsUpdateGroupRsp(BaseRsp): + pass + +@dataclass +class DestroyWeightsUpdateGroupReq(BaseReq): + group_name: str = "weight_update_group" + +@dataclass +class DestroyWeightsUpdateGroupRsp(BaseRsp): + pass + +@dataclass +class UpdateWeightsFromDistributedReq(BaseReq): + names: List[str] + dtypes: 
List[str] + shapes: List[List[int]] + # The group name + group_name: str = "weight_update_group" + # Whether to flush the cache after updating weights + flush_cache: bool = True + # Whether to abort all requests before updating weights + abort_all_requests: bool = False + # Optional: Update weight version along with weights + weight_version: Optional[str] = None + +@dataclass +class UpdateWeightsFromDistributedRsp(BaseRsp): + pass + + +@dataclass +class UpdateWeightsFromTensorReq(BaseReq): + """Update model weights from tensor input. + + - Tensors are serialized for transmission + - Data is structured in JSON for easy transmission over HTTP + """ + + serialized_named_tensors: List[Union[str, bytes]] + # Optional format specification for loading + load_format: Optional[str] = None + # Whether to flush the cache after updating weights + flush_cache: bool = True + # Whether to abort all requests before updating weights + abort_all_requests: bool = False + # Optional: Update weight version along with weights + weight_version: Optional[str] = None + +@dataclass +class UpdateWeightsFromTensorRsp(BaseRsp): + pass \ No newline at end of file diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 35fd861d8..ea1f17b90 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -35,12 +35,13 @@ GenerateReqIndex, FlushCacheReq, FlushCacheResp, + GeneralHttpToModelRpcReq, + GeneralModelToHttpRpcRsp ) from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name - logger = init_logger(__name__) @@ -549,7 +550,7 @@ async def _recv_new_reqs_and_schedule(self): recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) - elif isinstance(recv_req, FlushCacheReq): + elif isinstance(recv_req, (FlushCacheReq, GeneralHttpToModelRpcReq)): special_reqs.append(recv_req) # 当队列中存在较多的请求时,将一次接受的数量上调 @@ -574,6 +575,8 @@ def _process_special_reqs(self, special_reqs: List[BaseReq]): for req in special_reqs: if isinstance(req, FlushCacheReq): self.flush_cache() + elif isinstance(req, (GeneralHttpToModelRpcReq)): + self.forward_to_model(req) def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): req_num = len(reqs) @@ -609,6 +612,20 @@ def flush_cache(self) -> None: self.send_to_detokenization.send_pyobj(FlushCacheResp(success=success), protocol=pickle.HIGHEST_PROTOCOL) return + def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None: + ret = self.model_rpc_client.forward_to_model(req) + if self.is_multinode_tp: + output_list = [None for _ in self.nnodes] if self.node_rank == 0 else None + dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group) + for res in output_list: + res : GeneralModelToHttpRpcRsp + if not res.success: + ret = res + break + + if self.node_rank == 0: + self.send_to_detokenization.send_pyobj(ret, protocol=pickle.HIGHEST_PROTOCOL) + def clean_up(self): return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index db708f3cf..60431c8ff 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -4,7 +4,7 @@ import time import threading import torch.distributed as dist -from typing import List, Tuple, Callable, Optional +from typing 
import List, Tuple, Callable, Optional, Union from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger @@ -16,7 +16,7 @@ from lightllm.common.basemodel.basemodel import TpPartBaseModel from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify -from lightllm.utils.dist_utils import init_distributed_env +from lightllm.utils.dist_utils import init_distributed_env, init_custom_process_group from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd @@ -31,6 +31,9 @@ enable_radix_tree_timer_merge, get_radix_tree_merge_update_delta, ) +from lightllm.utils.serializer import LocalSerializedTensor, MultiprocessingSerializer +from lightllm.utils.patch_torch import monkey_patch_torch_reductions +from lightllm.utils.tensor_bucket import FlattenedTensorBucket, FlattenedTensorMetadata from lightllm.distributed import dist_group_manager from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack @@ -39,6 +42,12 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule +from lightllm.server.io_struct import ( + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq +) class ModeBackend: @@ -112,6 +121,8 @@ def init_model(self, kvargs): ) dist_group_manager.create_groups(group_size=group_size) # set the default group + self._model_update_group = {} + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 @@ -293,6 +304,172 @@ def flush_radix_cache(self): self.radix_cache.flush_cache() return + def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + assert ( + torch.distributed.is_initialized() + ), "Default torch process group must be initialized" + + assert request.group_name != "", "Group name cannot be empty" + rank = request.rank_offset + self.rank_in_dp + self.logger.info( + f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, " + f"rank_offset={request.rank_offset}, rank={rank}, world_size={request.world_size}, group_name={request.group_name}, " + f" backend={request.backend}" + ) + + try: + if request.group_name in self._model_update_group: + raise ValueError( + f"Process group with name {request.group_name} already exists." + ) + + self._model_update_group[request.group_name] = init_custom_process_group( + backend=request.backend, + init_method=f"tcp://{request.master_address}:{request.master_port}", + world_size=request.world_size, + rank=rank, + group_name=request.group_name, + ) + return True, "Succeeded to initialize custom process group." + + except Exception as e: + message = f"Failed to initialize custom process group: {e}." 
+ self.logger.error(message) + return False, message + + def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + try: + if request.group_name in self._model_update_group: + pg = self._model_update_group.pop(request.group_name) + torch.distributed.destroy_process_group(pg) + return True, "Succeeded to destroy custom process group." + else: + return False, "The group to be destroyed does not exist." + except Exception as e: + message = f"Failed to destroy custom process group: {e}." + self.logger.error(message) + return False, message + + def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + """ + Update specific parameter in the model weights online + through `_model_update_group` process group. + + Args: + name: the name of the parameter to be updated. + dtype: the data type of the parameter to be updated. + shape: the shape of the parameter to be updated. + """ + + assert request.group_name in self._model_update_group, ( + f"Group {request.group_name} not in {list(self._model_update_group.keys())}. " + "Please call `init_weights_update_group` first." + ) + + try: + weights = [] + handles = [] + for name, dtype, shape in zip(request.names, request.dtypes, request.shapes): + target_dtype = ( + dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + ) + weight = torch.empty(shape, dtype=target_dtype, device='cuda') + handles.append( + torch.distributed.broadcast( + weight, + src=0, + group=self._model_update_group[request.group_name], + async_op=True, + ) + ) + weights.append((name, weight)) + for handle in handles: + handle.wait() + + self.model.load_weights(weights) + return True, "Succeeded to update parameter online from distributed." + + except Exception as e: + error_msg = ( + f"Failed to update parameter online: {e}. " + f"The full weights of the ModelRunner are partially updated. " + f"Please discard the whole weights." + ) + self.logger.error(error_msg) + return False, error_msg + + def _update_weights_from_flattened_bucket( + self, + flattened_tensor_bucket_dict, + ): + """Handle flattened bucket format for weight updates""" + flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"] + metadata = flattened_tensor_bucket_dict["metadata"] + + # Convert metadata dict to our format + converted_metadata = [] + for meta in metadata: + converted_meta = FlattenedTensorMetadata( + name=meta.name, + shape=meta.shape, + dtype=meta.dtype, + start_idx=meta.start_idx, + end_idx=meta.end_idx, + numel=meta.numel, + ) + converted_metadata.append(converted_meta) + + # Create bucket and reconstruct tensors + bucket = FlattenedTensorBucket( + flattened_tensor=flattened_tensor, metadata=converted_metadata + ) + reconstructed_tensors = bucket.reconstruct_tensors() + + # Load the reconstructed tensors using the standard method + self.model.load_weights(reconstructed_tensors) + + return True, "Succeeded to update parameter online from flattened bucket tensor." 
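# Illustrative trainer-side sketch (not part of the patches) of the distributed
# weight-update flow implemented above, assuming a single inference rank (tp=1).
# Base URL, rendezvous address/port, parameter name and shape are placeholders.
# Both HTTP calls run in a thread because the server blocks inside them until the
# trainer performs the matching collective (group init / broadcast).
import threading
import requests
import torch
import torch.distributed as dist
from lightllm.utils.dist_utils import init_custom_process_group  # helper added in this commit

BASE = "http://127.0.0.1:8000"                 # inference server
MASTER_ADDR, MASTER_PORT = "127.0.0.1", 29600  # rendezvous for the update group

def post(path, body):
    return requests.post(f"{BASE}{path}", json=body)

# 1. The inference workers join a fresh group at rank_offset=1 while the trainer
#    creates it as rank 0 with the same group_name (world_size = 1 trainer + 1 worker).
t = threading.Thread(target=post, args=("/init_weights_update_group", {
    "master_address": MASTER_ADDR, "master_port": MASTER_PORT,
    "rank_offset": 1, "world_size": 2,
    "group_name": "weight_update_group", "backend": "nccl"}))
t.start()
group = init_custom_process_group(
    backend="nccl", init_method=f"tcp://{MASTER_ADDR}:{MASTER_PORT}",
    world_size=2, rank=0, group_name="weight_update_group")
t.join()

# 2. Announce name/dtype/shape over HTTP while broadcasting the data from rank 0;
#    the backend allocates matching buffers and receives them through the
#    broadcasts shown in update_weights_from_distributed above.
name = "model.layers.0.self_attn.q_proj.weight"  # placeholder parameter name
new_weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
t = threading.Thread(target=post, args=("/update_weights_from_distributed", {
    "names": [name], "dtypes": ["bfloat16"], "shapes": [[4096, 4096]],
    "group_name": "weight_update_group"}))
t.start()
dist.broadcast(new_weight, src=0, group=group)
t.join()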
+ + def update_weights_from_tensor( + self, + request: UpdateWeightsFromTensorReq + ): + try: + monkey_patch_torch_reductions() + if request.load_format == "flattened_bucket": + # Handle flattened bucket format + return self._update_weights_from_flattened_bucket( + flattened_tensor_bucket_dict=request.named_tensors + ) + + # We need to get device after patch otherwise the device would be wrong + self.device_module = torch.get_device_module("cuda") + infered_device = self.device_module.current_device() + + named_tensors=MultiprocessingSerializer.deserialize( + request.serialized_named_tensors[self.rank_in_dp] + ) + + def _unwrap_tensor(tensor, tp_rank, device): + if isinstance(tensor, LocalSerializedTensor): + tensor = tensor.get(tp_rank) + return tensor.to(device) + + named_tensors = { + name : _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) + for name, tensor in named_tensors + } + + self.model.load_weights(named_tensors) + + return True, "Succeeded to update parameter online from tensor." + + except Exception as e: + message = f"Failed to update parameter online from tensor. Reason: {e}." + self.logger.error(message) + + return False, message + def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor): """ 这个函数会把next token id和logprobs保存到pinned memory中 diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index b7797a762..04e7495ad 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -33,6 +33,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp logger = init_logger(__name__) @@ -190,6 +191,16 @@ def flush_radix_cache(self): logger.exception(f"flush radix cache failed: {str(e)}") return False + def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + try: + if self.backend is None or not hasattr(self.backend, req.func_name): + raise ValueError(f"Backend does not support function {req.func_name}") + success, ret = getattr(self.backend, req.func_name)(req.func_args) + return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret) + except BaseException as e: + logger.exception(f"forward to model backend failed: {str(e)}") + return GeneralModelToHttpRpcRsp(success=False, msg=f'forward to model backend failed: {str(e)}', func_name=req.func_name) + class ModelRpcClient: def __init__(self, rpc_event, rpc_finished_event): @@ -230,6 +241,15 @@ def flush_radix_cache(self) -> bool: assert func_name == "flush_radix_cache" return ret + def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + self.rpc_shm_params.write_func_params("forward_to_model", (req,)) + self.rpc_event.set() + + self.rpc_finished_event.wait() + self.rpc_finished_event.clear() + func_name, ret = self.rpc_shm_results.read_func_result() + assert func_name == "forward_to_model" + return ret def _init_env( args, diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 65ac401d4..28667c6d0 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -65,12 +65,15 @@ def init_vision_distributed_env(kvargs): device_id = visual_gpu_ids[kvargs["vit_rank_id"]] 
set_current_device_id(device_id) torch.cuda.set_device(device_id) + # 不要在init_process_group时,显示的传入device_id + # 这会触发torch的device-bound split优化,会默认后面想加入新进程组的rank + # 都已经存在于默认组,这样RL更新weight的init_group时,外部想加入的组,在执行 + # 通信原语时例如all_reduce,会永远等不到LightLLM默认组里的回复,从而导致错误结果。 dist.init_process_group( "nccl", init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}', rank=kvargs["tp_rank_id"], world_size=tp_world_size, - device_id=torch.device(f"cuda:{device_id}"), ) # warmup nccl communicator _a = torch.zeros([1]).to(f"cuda:{device_id}") @@ -104,7 +107,6 @@ def init_distributed_env(kvargs): init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}', rank=kvargs["rank_id"], world_size=kvargs["world_size"], - device_id=torch.device(f"cuda:{device_id}"), ) # warmup nccl communicator _a = torch.zeros([1]).to(f"cuda:{device_id}") @@ -270,3 +272,71 @@ def _init_nccl_env(): assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}" return + + +# copy from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py#L1675 +def init_custom_process_group( + backend=None, + init_method=None, + timeout=None, + world_size=-1, + rank=-1, + store=None, + group_name=None, + pg_options=None, + device_id=None, +): + from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, + ) + + assert (store is None) or (init_method is None), "Cannot specify both init_method and store." + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore(group_name, store) + + # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 + # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 + # We need to determine the appropriate parameter name based on PyTorch version + pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + device_id=device_id, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg diff --git a/lightllm/utils/patch_torch.py b/lightllm/utils/patch_torch.py new file mode 100644 index 000000000..c504e4bbc --- /dev/null +++ b/lightllm/utils/patch_torch.py @@ -0,0 +1,65 @@ +# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/patch_torch.py +from typing import Callable, Union + +import torch +from packaging import version +from torch.multiprocessing import reductions + + +def monkey_patch_torch_reductions(): + """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed""" + + # Currently, NPU does not support UUID. 
This has been temporarily commented out, with support expected in the fourth quarter. + # if _is_npu: + # return + + if hasattr(reductions, "_reduce_tensor_original"): + return + + reductions._reduce_tensor_original = reductions.reduce_tensor + reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor + + reductions.reduce_tensor = _reduce_tensor_modified + reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified + + reductions.init_reductions() + + +# The signature has not been changed for years, and we will not need this when the next version is released, +# so it looks safe to use a constant. +_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 + + +def _reduce_tensor_modified(*args, **kwargs): + output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) + output_args = _modify_tuple( + output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid + ) + return output_fn, output_args + + +def _rebuild_cuda_tensor_modified(*args): + args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) + return reductions._rebuild_cuda_tensor_original(*args) + + +def _device_to_uuid(device: int) -> str: + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: + if isinstance(device_maybe_uuid, int): + return device_maybe_uuid + + if isinstance(device_maybe_uuid, str): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: + return device + raise Exception("Invalid device_uuid=" + device_maybe_uuid) + + raise Exception(f"Unknown type: {device_maybe_uuid=}") + + +def _modify_tuple(t, index: int, modifier: Callable): + return *t[:index], modifier(t[index]), *t[index + 1 :] + diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py new file mode 100644 index 000000000..e0b523303 --- /dev/null +++ b/lightllm/utils/serializer.py @@ -0,0 +1,132 @@ + +# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py + +import base64 +import pickle +import io +from dataclasses import dataclass +from multiprocessing.reduction import ForkingPickler +from typing import List + + +class MultiprocessingSerializer: + @staticmethod + def serialize(obj, output_str: bool = False): + """ + Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + + if output_str: + # Convert bytes to base64-encoded string + output = base64.b64encode(output).decode("utf-8") + + return output + + @staticmethod + def deserialize(data): + """ + Deserialize a previously serialized object. + + Args: + data (bytes or str): The serialized data, optionally base64-encoded. + + Returns: + The deserialized Python object. 
+ """ + if isinstance(data, str): + # Decode base64 string to bytes + data = base64.b64decode(data, validate=True) + + return SafeUnpickler(io.BytesIO(data)).load() + + +class SafeUnpickler(pickle.Unpickler): + ALLOWED_MODULE_PREFIXES = { + # --- Python types --- + "builtins.", + "collections.", + "copyreg.", + "functools.", + "itertools.", + "operator.", + "types.", + "weakref.", + # --- PyTorch types --- + "torch.", + "torch._tensor.", + "torch.storage.", + "torch.nn.parameter.", + "torch.autograd.function.", + # --- torch distributed --- + "torch.distributed.", + "torch.distributed._shard.", + "torch.distributed._composable.", + "torch._C._distributed_c10d.", + "torch._C._distributed_fsdp.", + "torch.distributed.optim.", + # --- multiprocessing --- + "multiprocessing.resource_sharer.", + "multiprocessing.reduction.", + "pickletools.", + # --- PEFT / LoRA --- + "peft.", + "transformers.", + "huggingface_hub.", + # --- SGLang & Unitest --- + "sglang.srt.weight_sync.tensor_bucket.", + "sglang.srt.model_executor.model_runner.", + "sglang.srt.layers.", + "sglang.srt.utils.", + } + + DENY_CLASSES = { + ("builtins", "eval"), + ("builtins", "exec"), + ("builtins", "compile"), + ("os", "system"), + ("subprocess", "Popen"), + ("subprocess", "run"), + ("codecs", "decode"), + ("types", "CodeType"), + ("types", "FunctionType"), + } + + def find_class(self, module, name): + # Block deterministic attacks + if (module, name) in self.DENY_CLASSES: + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " + f"to prevent exploitation of CVE-2025-10164" + ) + # Allowlist of safe-to-load modules. + if any( + (module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES + ): + return super().find_class(module, name) + + # Block everything else. (Potential attack surface) + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " + f"to prevent exploitation of CVE-2025-10164" + ) + +@dataclass +class LocalSerializedTensor: + """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data). + The i-th element in the list corresponds to i-th rank's GPU.""" + + values: List[bytes] + + def get(self, rank: int): + return MultiprocessingSerializer.deserialize(self.values[rank]) \ No newline at end of file diff --git a/lightllm/utils/tensor_bucket.py b/lightllm/utils/tensor_bucket.py new file mode 100644 index 000000000..762bd0dd0 --- /dev/null +++ b/lightllm/utils/tensor_bucket.py @@ -0,0 +1,108 @@ +# copy from https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/srt/weight_sync/tensor_bucket.py +from dataclasses import dataclass +from typing import List, Tuple + +import torch + + +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. + """ + + # This field is solely for users of to check whether the class supports this feature + supports_multi_dtypes = True + + def __init__( + self, + named_tensors: List[Tuple[str, torch.Tensor]] = None, + flattened_tensor: torch.Tensor = None, + metadata: List[FlattenedTensorMetadata] = None, + ): + """ + Initialize a tensor bucket from a list of named tensors OR from pre-flattened data. 
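# Illustrative client-side sketch (not part of the patches) of the
# /update_weights_from_tensor path using MultiprocessingSerializer above. It
# assumes the client runs on the same machine as the inference server (the
# serialized payload carries CUDA IPC handles rather than the data), tp=1, and a
# placeholder tensor name; the source tensor must stay alive until the server has
# loaded it.
import requests
import torch
from lightllm.utils.serializer import MultiprocessingSerializer
from lightllm.utils.patch_torch import monkey_patch_torch_reductions

monkey_patch_torch_reductions()  # ship device UUIDs instead of bare device indices

named_tensors = [
    ("model.embed_tokens.weight", torch.randn(8, 8, dtype=torch.bfloat16, device="cuda")),
]

# One serialized blob per tp rank; with tp=1 a single entry is enough.
payload = MultiprocessingSerializer.serialize(named_tensors, output_str=True)
resp = requests.post("http://127.0.0.1:8000/update_weights_from_tensor", json={
    "serialized_named_tensors": [payload],
    "flush_cache": True,
})
print(resp.json())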
+ Args: + named_tensors: List of (name, tensor) tuples (for creating new bucket) + flattened_tensor: Pre-flattened tensor (for reconstruction) + metadata: Pre-computed metadata (for reconstruction) + """ + if named_tensors is not None: + # Create bucket from named tensors + self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors) + self.flattened_tensor: torch.Tensor = None + + if not named_tensors: + raise ValueError("Cannot create empty tensor bucket") + + # Collect metadata and flatten tensors + current_idx = 0 + flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors) + + for i, (name, tensor) in enumerate(named_tensors): + flattened = tensor.flatten().view(torch.uint8) + flattened_tensors[i] = flattened + + # Store metadata + + numel = flattened.numel() + metadata_obj = FlattenedTensorMetadata( + name=name, + shape=tensor.shape, + dtype=tensor.dtype, + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + self.metadata[i] = metadata_obj + current_idx += numel + + # Concatenate all flattened tensors + self.flattened_tensor = torch.cat(flattened_tensors, dim=0) + else: + # Initialize from pre-flattened data + if flattened_tensor is None or metadata is None: + raise ValueError( + "Must provide either named_tensors or both flattened_tensor and metadata" + ) + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Get the flattened tensor containing all bucket tensors""" + return self.flattened_tensor + + def get_metadata(self) -> List[FlattenedTensorMetadata]: + """Get metadata for all tensors in the bucket""" + return self.metadata + + def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + """ + Reconstruct original tensors from flattened tensor with optimized performance. + Uses memory-efficient operations to minimize allocations and copies. 
+ """ + # preallocate the result list + reconstructed = [None] * len(self.metadata) + + for i, meta in enumerate(self.metadata): + tensor = ( + self.flattened_tensor[meta.start_idx : meta.end_idx] + .view(meta.dtype) + .reshape(meta.shape) + ) + + reconstructed[i] = (meta.name, tensor) + + return reconstructed \ No newline at end of file From c210c82fa08f9933ab286e5b67d638da35c9a7aa Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Mon, 1 Dec 2025 20:06:06 +0800 Subject: [PATCH 08/71] release and resume (#1122) --- lightllm/common/basemodel/basemodel.py | 96 ++++++++++++++++--- lightllm/common/basemodel/cuda_graph.py | 9 +- .../basemodel/layer_weights/hf_load_utils.py | 70 +++++++++++++- lightllm/server/api_cli.py | 6 ++ lightllm/server/api_http.py | 40 +++++--- lightllm/server/core/objs/start_args_type.py | 22 ++--- lightllm/server/detokenization/manager.py | 12 +-- lightllm/server/httpserver/manager.py | 78 ++++++++------- lightllm/server/io_struct.py | 33 ++++++- lightllm/server/router/manager.py | 36 +++---- .../model_infer/mode_backend/base_backend.py | 75 ++++++++------- .../server/router/model_infer/model_rpc.py | 37 +++---- lightllm/utils/torch_memory_saver_utils.py | 92 ++++++++++++++++++ requirements.txt | 3 +- 14 files changed, 451 insertions(+), 158 deletions(-) create mode 100644 lightllm/utils/torch_memory_saver_utils.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 1221f1939..cc50d0a08 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -6,7 +6,7 @@ import json import torch import torch.nn.functional as F -from typing import final +from typing import final, List, Optional from tqdm import tqdm from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights @@ -30,6 +30,10 @@ from lightllm.utils.envs_utils import set_model_init_status from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) logger = init_logger(__name__) @@ -88,6 +92,7 @@ def __init__(self, kvargs): self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"] + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) self._init_datatype() self._init_config() @@ -97,20 +102,29 @@ def __init__(self, kvargs): # 更连续的显存分配可以有更好的性能 if self.max_total_token_num is None: - self._init_weights() - self._init_mem_manager() + with self.torch_memory_saver.region( + tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup + ): + self._init_weights() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_mem_manager() else: - self._init_mem_manager() - self._init_weights() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_mem_manager() + with self.torch_memory_saver.region( + tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup + ): + self._init_weights() self._init_kv_move_buffer() self._check_mem_size() - self._init_req_manager() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_req_manager() self._init_infer_layer() self._init_some_value() self._init_custom() self._init_inferstate_cls() - self._autotune_warmup() + # self._autotune_warmup() self._init_padded_req() # 
wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() @@ -179,11 +193,13 @@ def _init_weights(self): return def load_weights(self, weight_dict: dict): - load_hf_weights(self.data_type, - self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=weight_dict) + load_hf_weights( + self.data_type, + self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=weight_dict, + ) def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 @@ -766,6 +782,7 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) + torch.cuda.empty_cache() return def autotune_layers(self): @@ -896,6 +913,9 @@ def _init_padded_req(self): del b_seq_len del b_ready_cache_len del model_output + del b_mtp_index + del b_prefill_start_loc + del b_q_seq_len torch.cuda.empty_cache() return @@ -911,3 +931,55 @@ def _gen_special_model_input(self, token_num: int): special_model_input["deepseekv3_mtp_draft_input_hiddens"] = None return special_model_input + + def release_memory_occupation(self, tags: Optional[List[MemoryTag]]): + if tags is None: + self.release_all() + return + if MemoryTag.WEIGHT in tags: + self.release_weight() + if MemoryTag.KV_CACHE in tags: + self.release_kv_cache() + if MemoryTag.GRAPH in tags: + self.release_graph() + return + + def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]): + if tags is None: + self.resume_all() + return + if MemoryTag.WEIGHT in tags: + self.resume_weight() + if MemoryTag.KV_CACHE in tags: + self.resume_kv_cache() + if MemoryTag.GRAPH in tags: + self.resume_graph() + return + + def release_weight(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + + def release_kv_cache(self): + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + + def release_graph(self): + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + + def release_all(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + + def resume_weight(self): + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + + def resume_kv_cache(self): + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + + def resume_graph(self): + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) + + def resume_all(self): + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index c754fabce..220ae10cf 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -7,6 +7,10 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .infer_struct import InferStateInfo @@ -24,6 +28,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192): self.max_batch_size = max_batch_size self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.torch_memory_saver = 
TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) # gen cuda graph batch_sizes # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] @@ -82,7 +87,7 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf torch.cuda.synchronize() with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output = decode_func(input_ids, infer_state) self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output) graph_obj.replay() @@ -111,7 +116,7 @@ def _capture_decode_overlap( torch.cuda.synchronize() with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1) self.graph[batch_size] = ( graph_obj, diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 8cf66a5ad..2a9006efd 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -5,6 +5,8 @@ from tqdm import tqdm import lightllm.utils.petrel_helper as utils from lightllm.utils.dist_utils import get_current_device_id +from queue import Queue +from threading import Thread def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None): @@ -28,7 +30,7 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay gc.collect() -def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): +def load_hf_weights_old(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): if isinstance(data_type, str): data_type = torch.float16 if data_type == "fp16" else torch.float32 if pre_post_layer is not None: @@ -70,3 +72,69 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye pass return + + +def _read_file(file_, use_safetensors, weight_dir): + if use_safetensors: + weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + weights = {k: weights.get_tensor(k) for k in weights.keys()} + else: + weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") + + return weights + + +def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): + if isinstance(data_type, str): + data_type = torch.float16 if data_type == "fp16" else torch.float32 + if pre_post_layer is not None: + assert pre_post_layer.data_type_ == data_type, "type is not right" + if transformer_layer_list is not None: + assert transformer_layer_list[0].data_type_ == data_type, "type is not right" + if weight_dict: + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weight_dict) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weight_dict) + del weight_dict + return + use_safetensors = True + files = utils.PetrelHelper.list(weight_dir, extension="all") + candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files)) + if len(candidate_files) == 0: + use_safetensors = False + candidate_files = list(filter(lambda x: x.endswith(".bin"), files)) + assert len(candidate_files) != 0, "can only 
support pytorch tensor and safetensors format for weights." + + weight_queue = Queue(maxsize=5) # 控制内存使用 + + def producer(chunk): + for file_ in chunk: + weights = _read_file(file_, use_safetensors, weight_dir) + weight_queue.put(weights) + + LOADWORKER = int(os.environ.get("LOADWORKER", 1)) + + num_producers = min(LOADWORKER, len(candidate_files)) # 生产者数量 + chunk_size = (len(candidate_files) + num_producers - 1) // num_producers + file_chunks = [candidate_files[i : i + chunk_size] for i in range(0, len(candidate_files), chunk_size)] + + producer_threads = [] + for i, chunk in enumerate(file_chunks): + thread = Thread(target=producer, args=(chunk,), name=f"Producer-{i}") + thread.start() + producer_threads.append(thread) + + for _ in tqdm(range(len(candidate_files)), desc="Loading weights"): + weights = weight_queue.get() + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weights) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weights) + del weights + gc.collect() + + for thread in producer_threads: + thread.join() diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index ec11f8f1d..ee3f184e4 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -537,4 +537,10 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used.""" ) + parser.add_argument( + "--enable_torch_memory_saver", + action="store_true", + help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""", + ) + parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""") return parser diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 28ae93dfb..ff9acafc9 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -49,6 +49,7 @@ from lightllm.utils.error_utils import ServerBusyError from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import ReleaseMemoryReq, ResumeMemoryReq from dataclasses import dataclass from .api_openai import chat_completions_impl, completions_impl @@ -60,11 +61,12 @@ ) from .io_struct import ( AbortReq, + FlushCacheReq, InitWeightsUpdateGroupReq, DestroyWeightsUpdateGroupReq, UpdateWeightsFromDistributedReq, UpdateWeightsFromTensorReq, - GeneralModelToHttpRpcRsp + GeneralModelToHttpRpcRsp, ) from .build_prompt import build_prompt, init_tokenizer @@ -143,14 +145,17 @@ def get_model_name(): def get_server_info(): # 将 StartArgs 转换为字典格式 from dataclasses import asdict + server_info: dict[str, Any] = asdict(g_objs.args) return {**server_info} + @app.get("/get_weight_version") @app.post("/get_weight_version") def get_weight_version(): return {"weight_version": g_objs.args.weight_version} + @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") @@ -330,43 +335,38 @@ async def handle_request_common(request_obj, handler): else: return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) except Exception as e: - return create_error_response( - HTTPStatus.EXPECTATION_FAILED, - f"error: {str(e)}" - ) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + @app.post("/init_weights_update_group") async def 
init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request): """Init weights update group.""" return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group) + @app.post("/destroy_weights_update_group") async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request): """Destroy weights update group.""" return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group) + @app.post("/update_weights_from_distributed") async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request): """Update model parameter from distributed online.""" return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed) + @app.post("/update_weights_from_tensor") -async def update_weights_from_distributed(request: UpdateWeightsFromTensorReq, raw_request: Request): +async def update_weights_from_tensor(request: UpdateWeightsFromTensorReq, raw_request: Request): """Update model parameter from distributed online.""" return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor) + @app.post("/flush_cache") @app.get("/flush_cache") async def flush_cache(): """Flush the radix cache.""" - ret = await g_objs.httpserver_manager.flush_cache() - return Response( - content="Cache flushed successfully." - if ret - else "Cache flush failed. " - + "When there are running or waiting requests, the operation will not be performed.", - status_code=200 if ret else 500, - ) + return await handle_request_common(FlushCacheReq(), g_objs.httpserver_manager.flush_cache) @app.post("/pause_generation") @@ -381,6 +381,18 @@ async def continue_generation(): return Response(content="Generation continued successfully.", status_code=200) +@app.get("/release_memory_occupation") +@app.post("/release_memory_occupation") +async def release_memory_occupation(request: ReleaseMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.release_memory_occupation) + + +@app.get("/resume_memory_occupation") +@app.post("/resume_memory_occupation") +async def resume_memory_occupation(request: ResumeMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.resume_memory_occupation) + + @app.websocket("/pd_register") async def register_and_keep_alive(websocket: WebSocket): await websocket.accept() diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 40f68a743..eff4dfab5 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -8,7 +8,9 @@ class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"]}, + metadata={ + "choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"] + }, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) @@ -23,8 +25,7 @@ class StartArgs: config_server_port: int = field(default=None) pd_decode_rpyc_port: int = field(default=42000) select_p_d_node_strategy: str = field( - default="round_robin", - metadata={"choices": ["random", "round_robin", "adaptive_load"]} + default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]} ) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) @@ -59,10 +60,7 
@@ class StartArgs: disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - output_constraint_mode: str = field( - default="none", - metadata={"choices": ["outlines", "xgrammar", "none"]} - ) + output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]}) first_token_constraint_mode: bool = field(default=False) enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) @@ -107,8 +105,7 @@ class StartArgs: ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) mtp_mode: Optional[str] = field( - default=None, - metadata={"choices": ["deepseekv3_vanilla", "deepseekv3_eagle", None]} + default=None, metadata={"choices": ["deepseekv3_vanilla", "deepseekv3_eagle", None]} ) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) @@ -141,10 +138,7 @@ class StartArgs: httpserver_workers: int = field(default=1) disable_shm_warning: bool = field(default=False) - dp_balancer: str = field( - default="bs_balancer", - metadata={"choices": ["round_robin", "bs_balancer"]} - ) + dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]}) enable_custom_allgather: bool = field(default=False) enable_fused_shared_experts: bool = field(default=False) enable_mps: bool = field(default=False) @@ -152,5 +146,7 @@ class StartArgs: schedule_time_interval: float = field(default=0.03) use_dynamic_prompt_cache: bool = field(default=False) disable_custom_allreduce: bool = field(default=False) + enable_torch_memory_saver: bool = field(default=False) + enable_weight_cpu_backup: bool = field(default=False) weight_version: str = "default" diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index b7ba96025..17a47dfde 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -20,7 +20,9 @@ BaseReq, GenerateResp, FlushCacheResp, - GeneralModelToHttpRpcRsp + ReleaseMemoryResp, + ResumeMemoryResp, + GeneralModelToHttpRpcRsp, ) logger = init_logger(__name__) @@ -80,14 +82,8 @@ def handle_loop(self): # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_obj, FlushCacheResp): - print("Detokenization receive flush cache request", flush=True) + if isinstance(recv_obj, GeneralModelToHttpRpcRsp): self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) - print("Detokenization send flush cache request to httpserver", flush=True) - continue - elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): - self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) - print(f"Detokenization send {recv_obj.func_name} request to httpserver") continue self._add_new_group_req_index(recv_obj=recv_obj) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 44721cfe3..083e939ba 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -14,7 +14,7 @@ from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict, Optional +from typing import Union, List, Tuple, Dict, Optional, Literal from websockets import ClientConnection from 
fastapi import Request from ..tokenizer import get_tokenizer @@ -40,6 +40,10 @@ GenerateResp, GenerateReqMeta, GenerateReqIndex, + ReleaseMemoryReq, + ReleaseMemoryResp, + ResumeMemoryReq, + ResumeMemoryResp, InitWeightsUpdateGroupReq, InitWeightsUpdateGroupRsp, DestroyWeightsUpdateGroupReq, @@ -49,13 +53,13 @@ UpdateWeightsFromTensorReq, UpdateWeightsFromTensorRsp, GeneralHttpToModelRpcReq, - GeneralModelToHttpRpcRsp - + GeneralModelToHttpRpcRsp, ) from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.torch_memory_saver_utils import MemoryTag from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -140,6 +144,8 @@ def __init__( # 交互式请求 event self.flush_cache_event: Optional[asyncio.Event] = None + self.release_memory_event: Optional[asyncio.Event] = None + self.resume_memory_event: Optional[asyncio.Event] = None self.async_events_per_func: Dict[str, asyncio.Event] = {} return @@ -768,8 +774,6 @@ async def handle_loop(self): try: if recv_obj is None or isinstance(recv_obj, GenerateResp): await self._handle_recv_generate_request(recv_obj) - elif isinstance(recv_obj, FlushCacheResp): - await self._handle_recv_flush_cache_request(recv_obj) elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): await self._handle_recv_general_model_to_http_request(recv_obj) @@ -835,12 +839,6 @@ async def _handle_recv_generate_request(self, recv_obj: GenerateReqMeta): req_status.out_token_info_list.extend(token_list) req_status.event.set() - async def _handle_recv_flush_cache_request(self, recv_obj: FlushCacheResp): - assert self.flush_cache_event is not None - self.flush_cache_event.success = recv_obj.success - self.flush_cache_event.set() - return - async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralModelToHttpRpcRsp): assert recv_obj.func_name is not None event = await self.get_event_for_func(recv_obj.func_name) @@ -848,22 +846,11 @@ async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralMode event.set() return - async def flush_cache(self): - if self.flush_cache_event is None: - self.flush_cache_event = asyncio.Event() - await self.transfer_to_next_module(FlushCacheReq()) - try: - await asyncio.wait_for(self.flush_cache_event.wait(), timeout=30) - ret = self.flush_cache_event.success - except asyncio.TimeoutError: - # 超时直接返回失败 - ret = False - self.flush_cache_event.clear() - return ret - async def pause_generation(self): # 因为请求是从master node转发到slave node的 # 所以只要master暂停了,slave自然暂停。 + if self.is_pause: + return async with self.is_pause_cond: self.is_pause = True while True: @@ -883,7 +870,9 @@ async def get_event_for_func(self, func_name: str) -> asyncio.Event: self.async_events_per_func[func_name] = asyncio.Event() return self.async_events_per_func[func_name] - async def http_to_model_special_request(self, request: GeneralHttpToModelRpcReq, timeout: int=300) -> GeneralModelToHttpRpcRsp: + async def http_to_model_special_request( + self, request: GeneralHttpToModelRpcReq, timeout: int = 300 + ) -> GeneralModelToHttpRpcRsp: event = await self.get_event_for_func(request.func_name) await self.transfer_to_next_module(request) try: @@ -893,19 +882,41 @@ async def http_to_model_special_request(self, request: GeneralHttpToModelRpcReq, except asyncio.TimeoutError: ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response timeout", 
func_name=request.func_name) except Exception as e: - ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name) + ret = GeneralModelToHttpRpcRsp( + success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name + ) return ret + async def flush_cache(self, request: FlushCacheReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="flush_cache", func_args=request) + ) - async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): - return await self.http_to_model_special_request(GeneralHttpToModelRpcReq( - func_name="init_weights_update_group", func_args=request)) + async def release_memory_occupation(self, request: ReleaseMemoryReq): + assert len(self.req_id_to_out_inf) == 0, "there are still requests running, cannot release memory occupation" + # 暂停接受请求,除非resume + await self.pause_generation() + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="release_memory_occupation", func_args=request.tags) + ) + async def resume_memory_occupation(self, request: ResumeMemoryReq): + ret = await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="resume_memory_occupation", func_args=request.tags) + ) + if ret.success: + await self.continue_generation() + return ret - async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): - return await self.http_to_model_special_request(GeneralHttpToModelRpcReq( - func_name="destroy_weights_update_group", func_args=request)) + async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="init_weights_update_group", func_args=request) + ) + async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="destroy_weights_update_group", func_args=request) + ) async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): @@ -916,7 +927,8 @@ async def update_weights_from_distributed(self, request: UpdateWeightsFromDistri await self.flush_cache() return await self.http_to_model_special_request( - GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request)) + GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request) + ) async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> Tuple[bool, str]: if request.abort_all_requests: diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index 7947c7a58..e04e8871c 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -4,6 +4,7 @@ from lightllm.server.core.objs.sampling_params import SamplingParams from lightllm.server.multimodal_params import MultimodalParams from typing import List, Optional, Any, Union +from lightllm.utils.torch_memory_saver_utils import MemoryTag @dataclass @@ -14,11 +15,13 @@ def get_req_to_next_node(self): def get_req_to_next_module(self): return self + @dataclass class BaseRsp(ABC): success: bool msg: Optional[str] + # for next node @dataclass class GenerateReqMeta(BaseReq): @@ -85,16 +88,38 @@ class AbortReq(BaseReq): abort_all: bool = False +@dataclass +class ReleaseMemoryReq(BaseReq): + tags: Optional[List[MemoryTag]] = None + + +@dataclass +class ReleaseMemoryResp(BaseReq): + success: bool + + +@dataclass +class 
ResumeMemoryReq(BaseReq): + tags: Optional[List[MemoryTag]] = None + + +@dataclass +class ResumeMemoryResp(BaseReq): + success: bool + + @dataclass class GeneralHttpToModelRpcReq(BaseReq): func_name: str func_args: Optional[Any] = None + @dataclass class GeneralModelToHttpRpcRsp(BaseRsp): func_name: str func_rsp: Optional[Any] = None + @dataclass class InitWeightsUpdateGroupReq(BaseReq): # The master address @@ -110,18 +135,22 @@ class InitWeightsUpdateGroupReq(BaseReq): # The backend backend: str = "nccl" + @dataclass class InitWeightsUpdateGroupRsp(BaseRsp): pass + @dataclass class DestroyWeightsUpdateGroupReq(BaseReq): group_name: str = "weight_update_group" + @dataclass class DestroyWeightsUpdateGroupRsp(BaseRsp): pass + @dataclass class UpdateWeightsFromDistributedReq(BaseReq): names: List[str] @@ -136,6 +165,7 @@ class UpdateWeightsFromDistributedReq(BaseReq): # Optional: Update weight version along with weights weight_version: Optional[str] = None + @dataclass class UpdateWeightsFromDistributedRsp(BaseRsp): pass @@ -159,6 +189,7 @@ class UpdateWeightsFromTensorReq(BaseReq): # Optional: Update weight version along with weights weight_version: Optional[str] = None + @dataclass class UpdateWeightsFromTensorRsp(BaseRsp): - pass \ No newline at end of file + pass diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index ea1f17b90..20952e228 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -35,12 +35,17 @@ GenerateReqIndex, FlushCacheReq, FlushCacheResp, + ReleaseMemoryReq, + ReleaseMemoryResp, + ResumeMemoryReq, + ResumeMemoryResp, GeneralHttpToModelRpcReq, - GeneralModelToHttpRpcRsp + GeneralModelToHttpRpcRsp, ) from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.torch_memory_saver_utils import MemoryTag logger = init_logger(__name__) @@ -550,8 +555,10 @@ async def _recv_new_reqs_and_schedule(self): recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) - elif isinstance(recv_req, (FlushCacheReq, GeneralHttpToModelRpcReq)): + elif isinstance(recv_req, GeneralHttpToModelRpcReq): special_reqs.append(recv_req) + else: + raise ValueError(f"Unknown request type: {type(recv_req)}") # 当队列中存在较多的请求时,将一次接受的数量上调 self.recv_max_count = min(int(self.recv_max_count * 1.3), 256) @@ -573,10 +580,8 @@ def _process_special_reqs(self, special_reqs: List[BaseReq]): if self.is_multinode_tp: special_reqs = self.broadcast_reqs_to_other_nodes(special_reqs) for req in special_reqs: - if isinstance(req, FlushCacheReq): - self.flush_cache() - elif isinstance(req, (GeneralHttpToModelRpcReq)): - self.forward_to_model(req) + assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq" + self.forward_to_model(req) def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): req_num = len(reqs) @@ -595,30 +600,13 @@ def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) return reqs - def flush_cache(self) -> None: - # if radix cache client is not initialized, just return True - if self.radix_cache_client is None: - success = True - # only flush cache when no running batch and no waiting requests - elif self.running_batch is not None or self.req_queue.get_wait_req_num() > 0: - success = False - else: - 
success = self.model_rpc_client.flush_radix_cache() - - if self.is_multinode_tp: - # 等待其他节点的flush 结果 - dist.barrier(group=self.mulitnode_group) - if self.node_rank == 0: - self.send_to_detokenization.send_pyobj(FlushCacheResp(success=success), protocol=pickle.HIGHEST_PROTOCOL) - return - def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None: ret = self.model_rpc_client.forward_to_model(req) if self.is_multinode_tp: output_list = [None for _ in self.nnodes] if self.node_rank == 0 else None dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group) for res in output_list: - res : GeneralModelToHttpRpcRsp + res: GeneralModelToHttpRpcRsp if not res.success: ret = res break diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 60431c8ff..4be944584 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -41,12 +41,14 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.utils.torch_memory_saver_utils import MemoryTag from .multi_level_kv_cache import MultiLevelKvCacheModule from lightllm.server.io_struct import ( + FlushCacheReq, InitWeightsUpdateGroupReq, DestroyWeightsUpdateGroupReq, UpdateWeightsFromDistributedReq, - UpdateWeightsFromTensorReq + UpdateWeightsFromTensorReq, ) @@ -299,36 +301,54 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return - def flush_radix_cache(self): + def flush_cache(self, request: FlushCacheReq): if self.radix_cache is not None: self.radix_cache.flush_cache() - return + return True, "Succeeded to flush cache." + + def release_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.release_memory_occupation(tags) + self.flush_cache(request=None) + self.model.req_manager.free_all() + self.model.mem_manager.free_all() + return True, "Succeeded to release memory occupation." + except Exception as e: + self.logger.error(f"release memory occupation failed: {str(e)}") + return False, f"release memory occupation failed: {str(e)}" + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.resume_memory_occupation(tags) + return True, "Succeeded to resume memory occupation." 
+ except Exception as e: + self.logger.error(f"resume memory occupation failed: {str(e)}") + return False, f"resume memory occupation failed: {str(e)}" def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): - assert ( - torch.distributed.is_initialized() - ), "Default torch process group must be initialized" + assert torch.distributed.is_initialized(), "Default torch process group must be initialized" assert request.group_name != "", "Group name cannot be empty" - rank = request.rank_offset + self.rank_in_dp + rank_offset = request.rank_offset + rank = rank_offset + self.rank_in_dp + world_size = request.world_size + group_name = request.group_name self.logger.info( f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, " - f"rank_offset={request.rank_offset}, rank={rank}, world_size={request.world_size}, group_name={request.group_name}, " + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, " f" backend={request.backend}" ) try: - if request.group_name in self._model_update_group: - raise ValueError( - f"Process group with name {request.group_name} already exists." - ) + if group_name in self._model_update_group: + raise ValueError(f"Process group with name {group_name} already exists.") - self._model_update_group[request.group_name] = init_custom_process_group( + self._model_update_group[group_name] = init_custom_process_group( backend=request.backend, init_method=f"tcp://{request.master_address}:{request.master_port}", - world_size=request.world_size, + world_size=world_size, rank=rank, - group_name=request.group_name, + group_name=group_name, ) return True, "Succeeded to initialize custom process group." @@ -370,10 +390,8 @@ def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedR weights = [] handles = [] for name, dtype, shape in zip(request.names, request.dtypes, request.shapes): - target_dtype = ( - dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) - ) - weight = torch.empty(shape, dtype=target_dtype, device='cuda') + target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + weight = torch.empty(shape, dtype=target_dtype, device="cuda") handles.append( torch.distributed.broadcast( weight, @@ -420,9 +438,7 @@ def _update_weights_from_flattened_bucket( converted_metadata.append(converted_meta) # Create bucket and reconstruct tensors - bucket = FlattenedTensorBucket( - flattened_tensor=flattened_tensor, metadata=converted_metadata - ) + bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata) reconstructed_tensors = bucket.reconstruct_tensors() # Load the reconstructed tensors using the standard method @@ -430,25 +446,18 @@ def _update_weights_from_flattened_bucket( return True, "Succeeded to update parameter online from flattened bucket tensor." 
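For reference, a minimal round-trip with the bucket helper consumed above (a sketch, assuming the lightllm.utils.tensor_bucket module added earlier in this series is importable; the tensor names are illustrative):

    import torch
    from lightllm.utils.tensor_bucket import FlattenedTensorBucket

    # Sender side: pack two tensors of different dtypes into one uint8 buffer.
    named = [
        ("layer.weight", torch.randn(4, 8, dtype=torch.float16)),
        ("layer.bias", torch.randn(8, dtype=torch.float32)),
    ]
    bucket = FlattenedTensorBucket(named_tensors=named)
    flat, meta = bucket.get_flattened_tensor(), bucket.get_metadata()

    # Receiver side: rebuild the originals from the buffer plus the metadata list.
    restored = FlattenedTensorBucket(flattened_tensor=flat, metadata=meta).reconstruct_tensors()
    for (name, src), (restored_name, dst) in zip(named, restored):
        assert name == restored_name and torch.equal(src, dst)

The handler above only has to convert the transported metadata entries back into FlattenedTensorMetadata objects before calling reconstruct_tensors().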
- def update_weights_from_tensor( - self, - request: UpdateWeightsFromTensorReq - ): + def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): try: monkey_patch_torch_reductions() if request.load_format == "flattened_bucket": # Handle flattened bucket format - return self._update_weights_from_flattened_bucket( - flattened_tensor_bucket_dict=request.named_tensors - ) + return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=request.named_tensors) # We need to get device after patch otherwise the device would be wrong self.device_module = torch.get_device_module("cuda") infered_device = self.device_module.current_device() - named_tensors=MultiprocessingSerializer.deserialize( - request.serialized_named_tensors[self.rank_in_dp] - ) + named_tensors = MultiprocessingSerializer.deserialize(request.serialized_named_tensors[self.rank_in_dp]) def _unwrap_tensor(tensor, tp_rank, device): if isinstance(tensor, LocalSerializedTensor): @@ -456,7 +465,7 @@ def _unwrap_tensor(tensor, tp_rank, device): return tensor.to(device) named_tensors = { - name : _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) + name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) for name, tensor in named_tensors } diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 04e7495ad..399e9d240 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -33,6 +33,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.torch_memory_saver_utils import MemoryTag from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp logger = init_logger(__name__) @@ -182,13 +183,20 @@ def init_model(self, kvargs): def get_max_total_token_num(self): return self.backend.get_max_total_token_num() - def flush_radix_cache(self): + def release_memory_occupation(self, tags: List[MemoryTag]): try: - if self.backend is not None: - self.backend.flush_radix_cache() + self.backend.release_memory_occupation(tags) return True except BaseException as e: - logger.exception(f"flush radix cache failed: {str(e)}") + logger.exception(f"release memory occupation failed: {str(e)}") + return False + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.backend.resume_memory_occupation(tags) + return True + except BaseException as e: + logger.exception(f"resume memory occupation failed: {str(e)}") return False def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: @@ -199,7 +207,9 @@ def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpR return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret) except BaseException as e: logger.exception(f"forward to model backend failed: {str(e)}") - return GeneralModelToHttpRpcRsp(success=False, msg=f'forward to model backend failed: {str(e)}', func_name=req.func_name) + return GeneralModelToHttpRpcRsp( + success=False, msg=f"forward to model backend failed: {str(e)}", func_name=req.func_name + ) class ModelRpcClient: @@ -231,16 +241,6 @@ async def get_max_total_token_num(self): assert func_name == "get_max_total_token_num" return ret - def flush_radix_cache(self) -> bool: - 
self.rpc_shm_params.write_func_params("flush_radix_cache", ()) - self.rpc_event.set() - - self.rpc_finished_event.wait() - self.rpc_finished_event.clear() - func_name, ret = self.rpc_shm_results.read_func_result() - assert func_name == "flush_radix_cache" - return ret - def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: self.rpc_shm_params.write_func_params("forward_to_model", (req,)) self.rpc_event.set() @@ -251,6 +251,7 @@ def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpR assert func_name == "forward_to_model" return ret + def _init_env( args, rank, @@ -313,7 +314,11 @@ async def start_model_process( success_event, ), ) - proc.start() + from lightllm.utils.torch_memory_saver_utils import TorchMemorySaverWrapper + + torch_memory_saver = TorchMemorySaverWrapper(args.enable_torch_memory_saver) + with torch_memory_saver.configure_subprocess(): + proc.start() # Use asyncio.to_thread to make the blocking wait non-blocking await asyncio.to_thread(success_event.wait, timeout=40) diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py new file mode 100644 index 000000000..edf15fa83 --- /dev/null +++ b/lightllm/utils/torch_memory_saver_utils.py @@ -0,0 +1,92 @@ +import torch +from contextlib import contextmanager +from enum import Enum +from lightllm.utils.log_utils import init_logger + +try: + from torch_memory_saver import ( + torch_memory_saver, + configure_subprocess, + ) + + HAS_TORCH_MEMORY_SAVER = True + +except ImportError: + HAS_TORCH_MEMORY_SAVER = False + pass + +logger = init_logger(__name__) + + +class MemoryTag(Enum): + KV_CACHE = "kv_cache" + WEIGHT = "weight" + GRAPH = "graph" + + def is_kv_cache(self): + return self == MemoryTag.KV_CACHE + + def is_weight(self): + return self == MemoryTag.WEIGHT + + def is_graph(self): + return self == MemoryTag.GRAPH + + def __str__(self): + return self.value + + +class TorchMemorySaverWrapper: + def __new__(cls, enable_torch_memory_saver: bool = False): + if enable_torch_memory_saver: + assert ( + HAS_TORCH_MEMORY_SAVER + ), "torch_memory_saver is not installed, please install it via `pip install torch_memory_saver`." 
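+            # torch_memory_saver is importable here, so hand back the real wrapper;
+            # when the flag is off, the no-op fake below is returned instead and its
+            # pause()/resume() only log warnings.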
+ return _TorchMemorySaver() + else: + return _TorchMemorySaverFake() + + +class _TorchMemorySaver: + def configure_subprocess(self): + return configure_subprocess() + + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + return torch_memory_saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup) + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch_memory_saver.cuda_graph(cuda_graph=graph_obj, **kwargs, tag=MemoryTag.GRAPH.value) + + def disable(self): + return torch_memory_saver.disable() + + def pause(self, tag: MemoryTag): + return torch_memory_saver.pause(tag=tag.value) + + def resume(self, tag: MemoryTag): + return torch_memory_saver.resume(tag=tag.value) + + +class _TorchMemorySaverFake: + @contextmanager + def configure_subprocess(self): + yield + + @contextmanager + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + yield + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch.cuda.graph(graph_obj, **kwargs) + + @contextmanager + def disable(self): + yield + + def pause(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, pause is not supported.") + return + + def resume(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, resume is not supported.") + return diff --git a/requirements.txt b/requirements.txt index 40d0b4956..20f27dc05 100644 --- a/requirements.txt +++ b/requirements.txt @@ -87,4 +87,5 @@ librosa==0.11.0 cuda_bindings==12.9.0 orjson==3.11.2 setproctitle==1.3.6 -xxhash==3.6.0 \ No newline at end of file +xxhash==3.6.0 +torch_memory_saver==0.0.9 From 094df8ca341f3ae3d579933321677f7c6ad6424e Mon Sep 17 00:00:00 2001 From: sufubao <47234901+sufubao@users.noreply.github.com> Date: Mon, 8 Dec 2025 19:38:00 +0800 Subject: [PATCH 09/71] use portpicker (#1142) --- lightllm/utils/net_utils.py | 52 ++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index 20b988875..486414e88 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -2,44 +2,72 @@ import subprocess import ipaddress import random +import portpicker from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000): + if used_nccl_ports is None: + used_nccl_ports = [] + port_list = [] - for port in range(from_port_num, 65536): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) - if result != 0 and port not in used_nccl_ports: + max_attempts = num * 50 # Allow more attempts to find ports in range + + for _ in range(max_attempts): + if len(port_list) >= num: + break + + try: + port = portpicker.pick_unused_port() + + if port >= from_port_num and port not in used_nccl_ports: port_list.append(port) - if len(port_list) > num * 30: - break + logger.debug(f"Allocated port: {port}") + else: + logger.debug(f"Port {port} is out of range or in used_nccl_ports, skipping") + + except Exception as e: + logger.warning(f"Failed to allocate port: {e}") + continue if len(port_list) < num: + logger.error(f"Failed to allocate {num} ports, only got {len(port_list)}") return None - random.shuffle(port_list) - return port_list[0:num] + logger.info(f"Successfully allocated {len(port_list)} ports: {port_list}") + return port_list def alloc_can_use_port(min_port, max_port): port_list = [] for port in range(min_port, 
max_port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: port_list.append(port) + except Exception: + continue return port_list def find_available_port(start_port, end_port): for port in range(start_port, end_port + 1): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - result = sock.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: return port + except Exception: + continue return None From 560be020a0414ae0cb63197a2a1409e6cfd66895 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Mon, 8 Dec 2025 20:52:03 +0800 Subject: [PATCH 10/71] Rl weight (#1143) Co-authored-by: sufubao --- lightllm/common/basemodel/basemodel.py | 10 +- .../layer_weights/meta_weights/__init__.py | 5 +- .../{ => fused_moe}/fused_moe_weight_ep.py | 178 ++++- .../fused_moe_weight_ep_redundancy.py | 9 +- .../fused_moe/fused_moe_weight_tp.py | 325 +++++++++ .../gpt_oss_fused_moe_weight_tp.py | 2 +- .../meta_weights/fused_moe_weight_tp.py | 665 ------------------ .../meta_weights/mm_weight/__init__.py | 9 +- .../meta_weights/mm_weight/colmm_weight.py | 82 +-- .../meta_weights/mm_weight/mm_factory.py | 90 --- .../meta_weights/mm_weight/mm_slicer.py | 18 + .../meta_weights/mm_weight/mm_weight.py | 348 ++------- .../meta_weights/mm_weight/rowmm_weight.py | 88 +-- .../layer_weights/meta_weights/norm_weight.py | 50 +- .../layer_weights/transformer_layer_weight.py | 6 +- lightllm/common/quantization/__init__.py | 5 +- lightllm/common/quantization/awq_quant.py | 139 ++-- .../common/quantization/deepgemm_quant.py | 54 +- lightllm/common/quantization/no_quant.py | 52 ++ .../common/quantization/quantize_method.py | 66 +- lightllm/common/quantization/registry.py | 5 +- lightllm/common/quantization/torchao_quant.py | 9 +- .../fp8/fp8w8a8_block_quant_kernel.py | 2 +- .../fp8/fp8w8a8_scaled_mm_per_token_kernel.py | 471 +++++++++++++ .../quantization/triton_quant/triton_quant.py | 43 +- lightllm/common/quantization/w8a8_quant.py | 100 ++- .../pre_and_post_layer_weight.py | 25 +- .../pre_and_post_layer_weight.py | 17 +- .../pre_and_post_layer_weight.py | 6 +- .../layer_weights/transformer_layer_weight.py | 6 +- .../layer_weights/transformer_layer_weight.py | 8 +- .../pre_and_post_layer_weight.py | 20 +- .../layer_weights/transformer_layer_weight.py | 8 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 6 +- .../pre_and_post_layer_weight.py | 20 +- .../pre_and_post_layer_weight.py | 6 +- .../pre_and_post_layer_weight.py | 18 +- .../layer_weights/transformer_layer_weight.py | 53 +- .../pre_and_post_layer_weight.py | 6 +- .../layer_weights/transformer_layer_weight.py | 1 + .../pre_and_post_layer_weight.py | 18 +- .../pre_and_post_layer_weight.py | 21 +- .../layer_weights/transformer_layer_weight.py | 9 - .../pre_and_post_layer_weight.py | 32 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 5 +- .../pre_and_post_layer_weight.py | 23 
+- .../pre_and_post_layer_weight.py | 31 +- .../pre_and_post_layer_weight.py | 25 +- .../pre_and_post_layer_weight.py | 48 +- .../layer_weights/transformer_layer_weight.py | 9 +- .../mode_backend/redundancy_expert_manager.py | 4 +- 54 files changed, 1727 insertions(+), 1541 deletions(-) rename lightllm/common/basemodel/layer_weights/meta_weights/{ => fused_moe}/fused_moe_weight_ep.py (74%) rename lightllm/common/basemodel/layer_weights/meta_weights/{ => fused_moe}/fused_moe_weight_ep_redundancy.py (96%) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py rename lightllm/common/basemodel/layer_weights/meta_weights/{ => fused_moe}/gpt_oss_fused_moe_weight_tp.py (99%) delete mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py delete mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py create mode 100644 lightllm/common/quantization/no_quant.py create mode 100644 lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index cc50d0a08..3eb5d7dbe 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -120,6 +120,7 @@ def __init__(self, kvargs): self._check_mem_size() with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): self._init_req_manager() + self.load_weights(self.weight_dict) self._init_infer_layer() self._init_some_value() self._init_custom() @@ -181,15 +182,6 @@ def _init_weights(self): ) for i in range(self.config["n_layer"]) ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] return def load_weights(self, weight_dict: dict): diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index b3dab0614..e4f2beebc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -1,11 +1,10 @@ from .base_weight import BaseWeight from .mm_weight import ( - MMWeightPack, MMWeightTpl, ROWMMWeight, COLMMWeight, ROWBMMWeight, ) from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight -from .fused_moe_weight_tp import create_tp_moe_wegiht_obj -from .fused_moe_weight_ep import FusedMoeWeightEP +from .fused_moe.fused_moe_weight_tp import create_tp_moe_wegiht_obj +from .fused_moe.fused_moe_weight_ep import FusedMoeWeightEP diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py similarity index 74% rename from lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py index 87a7b361e..0923d5dea 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py @@ -3,7 +3,7 @@ import threading from typing import Optional, Tuple, List, Dict, Any from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id -from 
.base_weight import BaseWeight +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeight from lightllm.common.fused_moe.grouped_fused_moe_ep import ( fused_experts_impl, masked_group_gemm, @@ -23,6 +23,7 @@ from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair from lightllm.utils.log_utils import init_logger from lightllm.common.triton_utils.autotuner import Autotuner +from lightllm.common.quantization.quantize_method import WeightPack logger = init_logger(__name__) @@ -41,6 +42,7 @@ def __init__( network_config: Dict[str, Any], layer_num: int, quant_cfg=None, + hidden_size: Optional[int] = None, ) -> None: super().__init__() @@ -62,6 +64,7 @@ def __init__( self.e_score_correction_bias_name = e_score_correction_bias_name self.n_routed_experts = n_routed_experts self.data_type_ = data_type + self.hidden_size = hidden_size global_world_size = get_global_world_size() self.global_rank_ = get_global_rank() @@ -78,6 +81,7 @@ def __init__( assert self.n_routed_experts % global_world_size == 0 self.ep_n_routed_experts = self.n_routed_experts // global_world_size ep_load_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num + self.ep_load_expert_num = ep_load_expert_num self.experts_up_projs = [None] * ep_load_expert_num self.experts_gate_projs = [None] * ep_load_expert_num self.experts_up_proj_scales = [None] * ep_load_expert_num @@ -105,6 +109,51 @@ def __init__( # auto update redundancy expert vars self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert + # Pre-allocate memory if hidden_size is provided + if self.hidden_size is not None: + self._create_weight() + + def _create_weight(self): + """Pre-allocate GPU memory for fused MoE weights""" + if self.hidden_size is None: + return + + total_expert_num = self.ep_load_expert_num + # We need to determine intermediate size from network config or use a default + # This will be updated when first weight is loaded if needed + intermediate_size = getattr(self, "intermediate_size", None) + if intermediate_size is None: + # Default fallback - this will be corrected during load + intermediate_size = self.hidden_size * 4 + + device_id = get_current_device_id() + + if not self.quantized_weight and self.quant_method is not None: + # Quantized weights + w1_pack = self.quant_method.create_weight( + total_expert_num * intermediate_size * 2, self.hidden_size, dtype=self.data_type_, device_id=device_id + ) + self.w1[0] = w1_pack.weight.view(total_expert_num, intermediate_size * 2, self.hidden_size) + self.w1[1] = w1_pack.weight_scale.view(total_expert_num, intermediate_size * 2, self.hidden_size) + + w2_pack = self.quant_method.create_weight( + total_expert_num * self.hidden_size, intermediate_size, dtype=self.data_type_, device_id=device_id + ) + self.w2[0] = w2_pack.weight.view(total_expert_num, self.hidden_size, intermediate_size) + self.w2[1] = w2_pack.weight_scale.view(total_expert_num, self.hidden_size, intermediate_size) + else: + # Regular weights + self.w1[0] = torch.empty( + (total_expert_num, intermediate_size * 2, self.hidden_size), + dtype=self.data_type_, + device=f"cuda:{device_id}", + ) + self.w2[0] = torch.empty( + (total_expert_num, self.hidden_size, intermediate_size), + dtype=self.data_type_, + device=f"cuda:{device_id}", + ) + def experts( self, input_tensor, @@ -422,8 +471,12 @@ def _fuse(self): inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = 
torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) if not self.quantized_weight and self.quant_method is not None: - self.w1 = self.quant_method.quantize(w1) - self.w2 = self.quant_method.quantize(w2) + qw1_pack = self.quant_method.quantize(w1) + qw2_pack = self.quant_method.quantize(w2) + self.w1[0] = qw1_pack.weight + self.w1[1] = qw1_pack.weight_scale + self.w2[0] = qw2_pack.weight + self.w2[1] = qw2_pack.weight_scale else: self.w1[0] = self._cuda(w1) self.w2[0] = self._cuda(w2) @@ -465,38 +518,74 @@ def _fuse_weight_scale(self): def load_hf_weights(self, weights): n_expert_ep = self.ep_n_routed_experts - # tp to ep here + + # Load bias if self.e_score_correction_bias_name in weights: self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) + # Get weight shapes from first expert to determine intermediate size + first_expert_idx = 0 + n_expert_ep * self.global_rank_ + w1_weight_name = f"{self.weight_prefix}.{first_expert_idx}.{self.w1_weight_name}.weight" + if w1_weight_name in weights: + intermediate_size = weights[w1_weight_name].shape[0] + self.intermediate_size = intermediate_size + + # Re-create weights with correct size if needed + if self.w1[0].shape[1] != intermediate_size * 2: + self._create_weight() + + # Load regular experts for i_experts_ep in range(n_expert_ep): i_experts = i_experts_ep + n_expert_ep * self.global_rank_ - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - if w1_weight in weights: - self.experts_gate_projs[i_experts_ep] = weights[w1_weight] - if w3_weight in weights: - self.experts_up_projs[i_experts_ep] = weights[w3_weight] - if w2_weight in weights: - self.w2_list[i_experts_ep] = weights[w2_weight] - - # Load weight parameters for redundant experts + self._copy_expert_weights(i_experts_ep, i_experts, weights) + + # Load redundant experts for i, redundant_expert_id in enumerate(self.redundancy_expert_ids): - i_experts = redundant_expert_id - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - if w1_weight in weights: - self.experts_gate_projs[n_expert_ep + i] = weights[w1_weight] - if w3_weight in weights: - self.experts_up_projs[n_expert_ep + i] = weights[w3_weight] - if w2_weight in weights: - self.w2_list[n_expert_ep + i] = weights[w2_weight] + self._copy_expert_weights(n_expert_ep + i, redundant_expert_id, weights) if self.quantized_weight: - self._load_weight_scale(weights) - self._fuse() + self._load_weight_scale_direct(weights) + + def _copy_expert_weights(self, target_idx, expert_id, weights): + """Copy a single expert's weights to pre-allocated GPU memory""" + w1_weight = f"{self.weight_prefix}.{expert_id}.{self.w1_weight_name}.weight" + w2_weight = f"{self.weight_prefix}.{expert_id}.{self.w2_weight_name}.weight" + w3_weight = f"{self.weight_prefix}.{expert_id}.{self.w3_weight_name}.weight" + + intermediate_size = self.intermediate_size + + if w1_weight in weights and w3_weight in weights: + # Combine gate and up projections into w1 + gate_weight = weights[w1_weight] # [intermediate_size, hidden_size] + up_weight = weights[w3_weight] # [intermediate_size, hidden_size] + + # Copy to pre-allocated memory + if 
not self.quantized_weight and self.quant_method is not None: + # Quantized path + combined_cpu = torch.empty((intermediate_size * 2, self.hidden_size), dtype=gate_weight.dtype) + combined_cpu[:intermediate_size, :] = gate_weight + combined_cpu[intermediate_size:, :] = up_weight + quantized_pack = self.quant_method.quantize(combined_cpu) + self.w1[0][target_idx].copy_(quantized_pack.weight.view(intermediate_size * 2, self.hidden_size)) + if quantized_pack.weight_scale is not None: + self.w1[1][target_idx].copy_( + quantized_pack.weight_scale.view(intermediate_size * 2, self.hidden_size) + ) + else: + # Regular path + self.w1[0][target_idx, :intermediate_size, :].copy_(gate_weight) + self.w1[0][target_idx, intermediate_size:, :].copy_(up_weight) + + if w2_weight in weights: + # Copy w2 (down projection) + w2_weight_tensor = weights[w2_weight] # [hidden_size, intermediate_size] - already the correct shape + if not self.quantized_weight and self.quant_method is not None: + quantized_pack = self.quant_method.quantize(w2_weight_tensor) + self.w2[0][target_idx].copy_(quantized_pack.weight) + if quantized_pack.weight_scale is not None: + self.w2[1][target_idx].copy_(quantized_pack.weight_scale) + else: + self.w2[0][target_idx].copy_(w2_weight_tensor) def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: n_expert_ep = self.ep_n_routed_experts @@ -526,6 +615,41 @@ def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: if w2_scale in weights: self.w2_scale_list[n_expert_ep + i] = weights[w2_scale] + def _load_weight_scale_direct(self, weights: Dict[str, torch.Tensor]) -> None: + """Load weight scales directly to pre-allocated GPU memory""" + n_expert_ep = self.ep_n_routed_experts + + # Load regular expert scales + for i_experts_ep in range(n_expert_ep): + i_experts = i_experts_ep + n_expert_ep * self.global_rank_ + self._copy_expert_scales(i_experts_ep, i_experts, weights) + + # Load redundant expert scales + for i, redundant_expert_id in enumerate(self.redundancy_expert_ids): + self._copy_expert_scales(n_expert_ep + i, redundant_expert_id, weights) + + def _copy_expert_scales(self, target_idx, expert_id, weights): + """Copy a single expert's weight scales to pre-allocated GPU memory""" + w1_scale = f"{self.weight_prefix}.{expert_id}.{self.w1_weight_name}.{self.weight_scale_suffix}" + w2_scale = f"{self.weight_prefix}.{expert_id}.{self.w2_weight_name}.{self.weight_scale_suffix}" + w3_scale = f"{self.weight_prefix}.{expert_id}.{self.w3_weight_name}.{self.weight_scale_suffix}" + + intermediate_size = self.intermediate_size + + if w1_scale in weights and w3_scale in weights: + # Combine gate and up projection scales into w1 scale + gate_scale = weights[w1_scale] # [intermediate_size, hidden_size] + up_scale = weights[w3_scale] # [intermediate_size, hidden_size] + + # Copy to pre-allocated memory + self.w1[1][target_idx, :intermediate_size, :].copy_(gate_scale) + self.w1[1][target_idx, intermediate_size:, :].copy_(up_scale) + + if w2_scale in weights: + # Copy w2 scale (down projection) + w2_scale_tensor = weights[w2_scale] # [hidden_size, intermediate_size] + self.w2[1][target_idx].copy_(w2_scale_tensor) + def _cuda(self, cpu_tensor): device_id = get_current_device_id() if self.quantized_weight: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py similarity index 96% rename from 
lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py index 5558070a2..933a94f78 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py @@ -102,12 +102,15 @@ def _fuse(self): inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) if not self._ep_w.quantized_weight and self._ep_w.quant_method is not None: - self.w1 = self._ep_w.quant_method.quantize(w1) - self.w2 = self._ep_w.quant_method.quantize(w2) + qw1_pack = self._ep_w.quant_method.quantize(w1) + qw2_pack = self._ep_w.quant_method.quantize(w2) + self.w1[0] = qw1_pack.weight + self.w1[1] = qw1_pack.weight_scale + self.w2[0] = qw2_pack.weight + self.w2[1] = qw2_pack.weight_scale else: self.w1[0] = w1 self.w2[0] = w2 - delattr(self, "w2_list") delattr(self, "experts_up_projs") delattr(self, "experts_gate_projs") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py new file mode 100644 index 000000000..bf7b218b7 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py @@ -0,0 +1,325 @@ +import os +import torch +import threading +from typing import Tuple, List, Dict, Any, Union, Callable +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeight +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id, get_dp_world_size +from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.quantize_method import WeightPack +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import ( + get_row_slice_mixin, + get_col_slice_mixin, +) + + +def create_tp_moe_wegiht_obj( + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, +) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]: + quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + if quant_method is not None and quant_method.method_name == "awq_marlin": + return FusedAWQMARLINMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + quant_cfg=quant_cfg, + ) + else: + return FusedMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + 
quant_cfg=quant_cfg, + ) + + +class FusedMoeWeightTP(BaseWeight): + def __init__( + self, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, + ) -> None: + super().__init__() + self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + self.quantized_weight = quant_cfg.quantized_weight + if self.quant_method.method_name != "none": + self.weight_scale_suffix = self.quant_method.weight_scale_suffix + + self.w1_weight_name = gate_proj_name + self.w2_weight_name = down_proj_name + self.w3_weight_name = up_proj_name + + self.e_score_correction_bias_name = e_score_correction_bias_name + self.weight_prefix = weight_prefix + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." + self.n_routed_experts = n_routed_experts + num_fused_shared_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) + self.split_inter_size = split_inter_size + self.data_type_ = data_type + self.hidden_size = network_config.get("hidden_size") + self.tp_rank_ = get_current_rank_in_dp() + self.e_score_correction_bias = None + self.scoring_func = network_config.get("scoring_func", "softmax") + self.row_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size() + ) + self.col_slicer = get_col_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size() + ) + self._create_weight() + + def _create_weight(self): + total_expert_num = self.n_routed_experts + intermediate_size = self.split_inter_size + device_id = get_current_device_id() + + # Create e_score_correction_bias + if self.e_score_correction_bias is not None: + self.e_score_correction_bias = torch.empty( + (total_expert_num,), + dtype=self.data_type_, + device=f"cuda:{device_id}", + ) + + self.w13: WeightPack = self.quant_method.create_weight( + out_dim=intermediate_size * 2, + in_dim=self.hidden_size, + dtype=self.data_type_, + device_id=device_id, + num_experts=total_expert_num, + ) + self.w2: WeightPack = self.quant_method.create_weight( + out_dim=self.hidden_size, + in_dim=intermediate_size, + dtype=self.data_type_, + device_id=device_id, + num_experts=total_expert_num, + ) + + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + from lightllm.common.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=self.e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=self.scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts - self.num_fused_shared_experts, + end=self.n_routed_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + 
dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + + w13, w13_scale = self.w13.weight, self.w13.weight_scale + w2, w2_scale = self.w2.weight, self.w2.weight_scale + use_fp8_w8a8 = self.quant_method.method_name != "none" + + from lightllm.common.fused_moe.grouped_fused_moe import fused_experts + + fused_experts( + hidden_states=input_tensor, + w1=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w13_scale, + w2_scale=w2_scale, + ) + return + + def _cuda(self, cpu_tensor): + device_id = get_current_device_id() + if self.quantized_weight: + return cpu_tensor.cuda(device_id) + return cpu_tensor.cuda(device_id) + + def verify_load(self): + return True + + def load_hf_weights(self, weights): + # Load bias + if self.e_score_correction_bias_name in weights: + self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + + # Load each expert with TP slicing + for i_experts in range(self.n_routed_experts): + self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix) + if self.w13.weight_scale is not None: + self._load_expert(i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix) + if self.w13.weight_zero_point is not None: + self._load_expert( + i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix + ) + + def _load_weight_func(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0): + if self.quant_method.weight_need_quanted(weight): + self.quant_method.quantize(weight, weight_pack, start_idx) + else: + self.quant_method.load_weight(weight, weight_pack, start_idx) + + def _load_expert(self, expert_idx, weights, type: str, suffix: str = "weight"): + w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{suffix}" + w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{suffix}" + w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{suffix}" + intermediate_size = self.split_inter_size + load_func, slice_func = self._get_load_and_slice_func(type, is_row=True) + if w1_weight in weights: + load_func(slice_func(weights[w1_weight]), self.w13.get_expert(expert_idx), start_idx=0) + if w3_weight in weights: + load_func(slice_func(weights[w3_weight]), self.w13.get_expert(expert_idx), start_idx=intermediate_size) + + load_func, slice_func = self._get_load_and_slice_func(type, is_row=False) + if w2_weight in weights: + load_func(slice_func(weights[w2_weight]), self.w2.get_expert(expert_idx), start_idx=0) + + def _get_load_and_slice_func(self, type: str, is_row: bool = True): + if is_row: + slicer = self.row_slicer + else: + slicer = self.col_slicer + if type == "weight": + return self._load_weight_func, slicer._slice_weight + elif type == "weight_scale": + return getattr(self.quant_method, "load_weight_scale"), slicer._slice_weight_scale + elif type == "weight_zero_point": + return getattr(self.quant_method, "load_weight_zero_point"), slicer._slice_weight_zero_point + + +class FusedAWQMARLINMoeWeightTP(FusedMoeWeightTP): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + + assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + ) + + 
self.workspace = marlin_make_workspace_new(self.w13.weight.device, 4) + + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + from lightllm.common.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=self.e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=self.scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts - self.num_fused_shared_experts, + end=self.n_routed_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + + w1, w1_scale, w1_zero_point = self.w13.weight, self.w13.weight_scale, self.w13.weight_zero_point + w2, w2_scale, w2_zero_point = self.w2.weight, self.w2.weight_scale, self.w2.weight_zero_point + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe + + fused_marlin_moe( + input_tensor, + w1, + w2, + None, + None, + w1_scale, + w2_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_method.vllm_quant_type.id, + apply_router_weight_on_input=False, + global_num_experts=-1, + expert_map=None, + w1_zeros=w1_zero_point, + w2_zeros=w2_zero_point, + workspace=self.workspace, + inplace=True, + ) + + return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py similarity index 99% rename from lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index df72cc620..9d79ff7c2 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -3,7 +3,7 @@ import threading from typing import Optional, Tuple, List, Dict, Any -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_tp import FusedMoeWeightTP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_tp import FusedMoeWeightTP from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id from lightllm.common.quantization import Quantcfg from lightllm.utils.log_utils import init_logger diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py deleted file mode 100644 index 0449db344..000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ /dev/null @@ -1,665 +0,0 @@ -import os -import torch -import threading -from typing import Optional, Tuple, List, Dict, Any, Union -from .base_weight import BaseWeight -from lightllm.utils.dist_utils import get_current_rank_in_dp, 
get_current_device_id -from lightllm.common.quantization import Quantcfg - - -def create_tp_moe_wegiht_obj( - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, -) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]: - quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - if quant_method is not None and quant_method.method_name == "awq_marlin": - return FusedAWQMARLINMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - quant_cfg=quant_cfg, - ) - else: - return FusedMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - quant_cfg=quant_cfg, - ) - - -class FusedMoeWeightTP(BaseWeight): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, - ) -> None: - super().__init__() - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.quant_method.is_moe = True - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - - self.e_score_correction_bias_name = e_score_correction_bias_name - self.weight_prefix = weight_prefix - assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
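[Editor's illustrative sketch, not part of the patch: the reworked MoE weight classes above replace the old collect-per-expert-on-CPU-then-fuse flow with buffers that are allocated once up front and filled expert by expert via copy_(). The snippet below shows only that pre-allocate-then-copy pattern; the shapes and names are simplified stand-ins, not the actual lightllm classes.]

import torch

num_experts, intermediate_size, hidden_size = 4, 128, 64
dtype = torch.float32  # simplified; the real code uses the model's data_type_

# 1) allocate the fused gate/up buffer once, up front
w13 = torch.empty((num_experts, intermediate_size * 2, hidden_size), dtype=dtype)

# 2) copy each expert's gate and up projection directly into its slot,
#    instead of keeping per-expert CPU tensors and concatenating at the end
def copy_expert(target_idx: int, gate: torch.Tensor, up: torch.Tensor) -> None:
    w13[target_idx, :intermediate_size, :].copy_(gate)
    w13[target_idx, intermediate_size:, :].copy_(up)

for e in range(num_experts):
    gate_w = torch.randn(intermediate_size, hidden_size, dtype=dtype)
    up_w = torch.randn(intermediate_size, hidden_size, dtype=dtype)
    copy_expert(e, gate_w, up_w)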
- self.n_routed_experts = n_routed_experts + num_fused_shared_experts - self.num_fused_shared_experts = num_fused_shared_experts - self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size - self.data_type_ = data_type - self.tp_rank_ = get_current_rank_in_dp() - self.experts_up_projs = [None] * self.n_routed_experts - self.experts_gate_projs = [None] * self.n_routed_experts - self.experts_up_proj_scales = [None] * self.n_routed_experts - self.experts_gate_proj_scales = [None] * self.n_routed_experts - self.e_score_correction_bias = None - self.w2_list = [None] * self.n_routed_experts - self.w2_scale_list = [None] * self.n_routed_experts - self.scoring_func = network_config.get("scoring_func", "softmax") - self.w1 = [None, None] # weight, weight_scale - self.w2 = [None, None] # weight, weight_scale - self.lock = threading.Lock() - - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - - w1, w1_scale = self.w1 - w2, w2_scale = self.w2 - use_fp8_w8a8 = self.quant_method is not None - - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts - - fused_experts( - hidden_states=input_tensor, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - w1_scale=w1_scale, - w2_scale=w2_scale, - ) - return - - def _fuse(self): - if self.quantized_weight: - self._fuse_weight_scale() - with self.lock: - if ( - hasattr(self, "experts_up_projs") - and None not in self.experts_up_projs - and None not in self.experts_gate_projs - and None not in self.w2_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape - up_out_dim, up_in_dim = self.experts_up_projs[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_projs[0].dtype - total_expert_num = self.n_routed_experts - - w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") - - for i_experts in range(self.n_routed_experts): - w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] - w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] - - inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - if not self.quantized_weight and self.quant_method is not None: - self.w1 = self.quant_method.quantize(w1) - self.w2 = 
self.quant_method.quantize(w2) - else: - self.w1[0] = self._cuda(w1) - self.w2[0] = self._cuda(w2) - delattr(self, "w2_list") - delattr(self, "experts_up_projs") - delattr(self, "experts_gate_projs") - - def _fuse_weight_scale(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_scales") - and None not in self.experts_up_proj_scales - and None not in self.experts_gate_proj_scales - and None not in self.w2_scale_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape - up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_proj_scales[0].dtype - total_expert_num = self.n_routed_experts - - w1_scale = torch.empty( - (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu" - ) - - for i_experts in range(self.n_routed_experts): - w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] - inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] - w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( - len(self.w2_scale_list), inter_shape, hidden_size - ) - self.w1[1] = self._cuda(w1_scale) - self.w2[1] = self._cuda(w2_scale) - delattr(self, "w2_scale_list") - delattr(self, "experts_up_proj_scales") - delattr(self, "experts_gate_proj_scales") - - def load_hf_weights(self, weights): - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - for i_experts in range(self.n_routed_experts): - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - if self.quant_method is not None: - self._load_weight_scale(weights) - self._fuse() - - def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: - block_size = 1 - if hasattr(self.quant_method, "block_size"): - block_size = self.quant_method.block_size - for i_experts in range(self.n_routed_experts): - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - if w1_scale in weights: - self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - :, - ] - if w3_scale in weights: - self.experts_up_proj_scales[i_experts] = weights[w3_scale][ - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - :, - ] - - if w2_scale in weights: - self.w2_scale_list[i_experts] 
= weights[w2_scale][ - :, - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - ] - - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - if self.quantized_weight: - return cpu_tensor.contiguous().cuda(device_id) - return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None - - -class FusedAWQMARLINMoeWeightTP(BaseWeight): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, - ) -> None: - super().__init__() - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.weight_zero_point_suffix = self.quant_method.weight_zero_point_suffix - self.quant_method.is_moe = True - hf_quantization_config = network_config.get("quantization_config", None) - self.num_bits = hf_quantization_config.get("bits", 4) - self.group_size = hf_quantization_config.get("group_size", 128) - self.pack_factor = 32 // self.num_bits - self.has_processed_weight = False - assert self.quant_method.method_name == "awq_marlin" - - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - - self.e_score_correction_bias_name = e_score_correction_bias_name - self.weight_prefix = weight_prefix - assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
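[Editor's illustrative sketch, not part of the patch: the explicit per-rank indexing in the deleted loader above is what the new get_row_slice_mixin / get_col_slice_mixin helpers encapsulate. The classes below are simplified stand-ins for the unquantized case, not the real mm_slicer implementations.]

import torch

class RowSlice:
    def __init__(self, tp_rank: int, tp_world_size: int):
        self.tp_rank, self.tp_world_size = tp_rank, tp_world_size

    def slice_weight(self, w: torch.Tensor) -> torch.Tensor:
        # split the output (row) dimension across tensor-parallel ranks
        chunk = w.shape[0] // self.tp_world_size
        return w[self.tp_rank * chunk : (self.tp_rank + 1) * chunk, :]

class ColSlice(RowSlice):
    def slice_weight(self, w: torch.Tensor) -> torch.Tensor:
        # split the input (column) dimension instead
        chunk = w.shape[1] // self.tp_world_size
        return w[:, self.tp_rank * chunk : (self.tp_rank + 1) * chunk]

w = torch.randn(8, 6)
assert RowSlice(1, 2).slice_weight(w).shape == (4, 6)
assert ColSlice(1, 2).slice_weight(w).shape == (8, 3)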
- self.n_routed_experts = n_routed_experts + num_fused_shared_experts - self.num_fused_shared_experts = num_fused_shared_experts - self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size - self.data_type_ = data_type - self.tp_rank_ = get_current_rank_in_dp() - self.experts_up_projs = [None] * self.n_routed_experts - self.experts_gate_projs = [None] * self.n_routed_experts - self.experts_up_proj_scales = [None] * self.n_routed_experts - self.experts_up_proj_zero_points = [None] * self.n_routed_experts - self.experts_gate_proj_scales = [None] * self.n_routed_experts - self.experts_gate_proj_zero_points = [None] * self.n_routed_experts - self.e_score_correction_bias = None - self.w2_list = [None] * self.n_routed_experts - self.w2_scale_list = [None] * self.n_routed_experts - self.w2_zero_point_list = [None] * self.n_routed_experts - self.scoring_func = network_config.get("scoring_func", "softmax") - self.w1 = [None, None, None] # weight, weight_scale, zero_point - self.w2 = [None, None, None] # weight, weight_scale, zero_point - self.lock = threading.Lock() - - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - - w1, w1_scale, w1_zero_point = self.w1 - w2, w2_scale, w2_zero_point = self.w2 - - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe - - fused_marlin_moe( - input_tensor, - w1, - w2, - None, - None, - w1_scale, - w2_scale, - router_logits, - topk_weights, - topk_ids, - quant_type_id=self.quant_method.vllm_quant_type.id, - apply_router_weight_on_input=False, - global_num_experts=-1, - expert_map=None, - w1_zeros=w1_zero_point, - w2_zeros=w2_zero_point, - workspace=self.workspace, - inplace=True, - ) - - return - - def _fuse(self): - self._fuse_weight() - self._fuse_weight_scale() - self._fuse_weight_zero_point() - - def _fuse_weight(self): - with self.lock: - if ( - hasattr(self, "experts_up_projs") - and None not in self.experts_up_projs - and None not in self.experts_gate_projs - and None not in self.w2_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_projs[0].shape - up_in_dim, up_out_dim = self.experts_up_projs[0].shape - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - - w1 = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" - ) - - for i_experts in range(self.n_routed_experts): - w1[i_experts, :, 0:gate_out_dim] 
= self.experts_gate_projs[i_experts] - w1[i_experts, :, gate_out_dim:] = self.experts_up_projs[i_experts] - - inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - self.w1[0] = self._cuda(w1) - self.w2[0] = self._cuda(w2) - delattr(self, "w2_list") - delattr(self, "experts_up_projs") - delattr(self, "experts_gate_projs") - - def _fuse_weight_scale(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_scales") - and None not in self.experts_up_proj_scales - and None not in self.experts_gate_proj_scales - and None not in self.w2_scale_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_proj_scales[0].shape - up_in_dim, up_out_dim = self.experts_up_proj_scales[0].shape - dtype = self.experts_gate_proj_scales[0].dtype - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - w1_scale = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=dtype, device="cpu" - ) - for i_experts in range(self.n_routed_experts): - w1_scale[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, :, gate_out_dim:] = self.experts_up_proj_scales[i_experts] - inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] - w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( - len(self.w2_scale_list), inter_shape, hidden_size - ) - self.w1[1] = self._cuda(w1_scale).to(self.data_type_) - self.w2[1] = self._cuda(w2_scale).to(self.data_type_) - delattr(self, "w2_scale_list") - delattr(self, "experts_up_proj_scales") - delattr(self, "experts_gate_proj_scales") - - def _fuse_weight_zero_point(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_zero_points") - and None not in self.experts_up_proj_zero_points - and None not in self.experts_gate_proj_zero_points - and None not in self.w2_zero_point_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_proj_zero_points[0].shape - up_in_dim, up_out_dim = self.experts_up_proj_zero_points[0].shape - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - w1_zero_point = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" - ) - for i_experts in range(self.n_routed_experts): - w1_zero_point[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_zero_points[i_experts] - w1_zero_point[i_experts, :, gate_out_dim:] = self.experts_up_proj_zero_points[i_experts] - inter_shape, hidden_size = self.w2_zero_point_list[0].shape[0], self.w2_zero_point_list[0].shape[1] - w2_zero_point = torch._utils._flatten_dense_tensors(self.w2_zero_point_list).view( - len(self.w2_zero_point_list), inter_shape, hidden_size - ) - self.w1[2] = self._cuda(w1_zero_point) - self.w2[2] = self._cuda(w2_zero_point) - delattr(self, "w2_zero_point_list") - delattr(self, "experts_up_proj_zero_points") - delattr(self, "experts_gate_proj_zero_points") - - def load_hf_weights(self, weights): - self._load_weight(weights) - self._load_weight_scale(weights) - self._load_weight_zero_point(weights) - self._fuse() - self._process_weight_after_loading() - - def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: - # awq quantization weight shape: in x out - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - for i_experts in range(self.n_routed_experts): - 
w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.qweight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.qweight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.qweight" - - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - - def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: - for i_experts in range(self.n_routed_experts): - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - split_inter_size = self.split_inter_size * self.pack_factor - if w1_scale in weights: - self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ - :, - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - ] - if w3_scale in weights: - self.experts_up_proj_scales[i_experts] = weights[w3_scale][ - :, - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - ] - - if w2_scale in weights: - self.w2_scale_list[i_experts] = weights[w2_scale][ - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - :, - ] - - def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: - for i_experts in range(self.n_routed_experts): - w1_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_zero_point_suffix}" - w2_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_zero_point_suffix}" - w3_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_zero_point_suffix}" - if w1_zero_point in weights: - self.experts_gate_proj_zero_points[i_experts] = weights[w1_zero_point][ - :, - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - ] - if w3_zero_point in weights: - self.experts_up_proj_zero_points[i_experts] = weights[w3_zero_point][ - :, - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - ] - if w2_zero_point in weights: - self.w2_zero_point_list[i_experts] = weights[w2_zero_point][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - :, - ] - - def _process_weight_after_loading(self): - with self.lock: - if None in self.w1 or None in self.w2 or self.has_processed_weight: - return - self.has_processed_weight = True - from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - - assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" - - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_moe_permute_scales, - moe_awq_to_marlin_zero_points, - marlin_make_workspace_new, - ) - - num_experts = self.n_routed_experts - device = self.w1[0].device - - self.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - self.w2_g_idx_sort_indices = 
torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - self.w1[0] = vllm_ops.awq_marlin_moe_repack( - self.w1[0], - self.w13_g_idx_sort_indices, - size_k=self.w1[0].shape[1], - size_n=self.w1[0].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.w2[0] = vllm_ops.awq_marlin_moe_repack( - self.w2[0], - self.w2_g_idx_sort_indices, - size_k=self.w2[0].shape[1], - size_n=self.w2[0].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - # Why does this take the intermediate size for size_k? - self.w1[1] = marlin_moe_permute_scales( - s=self.w1[1], - size_k=self.split_inter_size * self.pack_factor, - size_n=self.w1[1].shape[2], - group_size=self.group_size, - ) - - self.w2[1] = marlin_moe_permute_scales( - s=self.w2[1], - size_k=self.split_inter_size * self.pack_factor, - size_n=self.w2[1].shape[2], - group_size=self.group_size, - ) - - self.w1[2] = moe_awq_to_marlin_zero_points( - self.w1[2], - size_k=self.w1[2].shape[1], - size_n=self.w1[2].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.w2[2] = moe_awq_to_marlin_zero_points( - self.w2[2], - size_k=self.w2[2].shape[1], - size_n=self.w2[2].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.workspace = marlin_make_workspace_new(device, 4) - - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - if self.quantized_weight: - return cpu_tensor.cuda(device_id) - return cpu_tensor.cuda(device_id) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index 63605b177..34d989b01 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -1,10 +1,5 @@ from .mm_weight import ( - MMWeightPack, MMWeightTpl, ) -from .mm_factory import ( - MMWeight, - ROWMMWeight, - ROWBMMWeight, - COLMMWeight, -) +from .rowmm_weight import ROWMMWeight, ROWBMMWeight +from .colmm_weight import COLMMWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index 281f30f02..bf73b9ad8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -1,19 +1,19 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( MMWeightTpl, - DeepGemmFP8W8A8B128MMWeight, - AWQMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union -from .mm_slicer import ColSliceMixin, QuantizedColSliceMixin, AwqQuantizedColSliceMixin +from .mm_slicer import get_col_slice_mixin -class StandardCOLMMWeight(MMWeightTpl): +class COLMMWeight(MMWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -22,6 +22,8 @@ def __init__( tp_world_size: int = None, ) -> None: super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, data_type=data_type, 
bias_names=bias_names, @@ -29,74 +31,6 @@ def __init__( tp_rank=tp_rank, tp_world_size=tp_world_size, ) - self.param_slicer = ColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQCOLMMWeight(AWQMMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + self.param_slicer = get_col_slice_mixin( + self.quant_method.method_name, tp_rank=tp_rank, tp_world_size=tp_world_size ) - # 注意这里不是错误,因为awq的weight是按inxout存的 - self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQMARLINCOLMMWeight(AWQCOLMMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - -COLMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight, - "awq": AWQCOLMMWeight, - "awq_marlin": AWQMARLINCOLMMWeight, -} diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py deleted file mode 100644 index 464de8441..000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py +++ /dev/null @@ -1,90 +0,0 @@ -from lightllm.common.quantization import Quantcfg -from lightllm.common.quantization.quantize_method import QuantizationMethod -from typing import Type, Union, Dict -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeightTpl, - BMMWeightTpl, -) -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ( - StandardROWMMWeight, - UnquantizedROWBMMWeight, - ROWMM_WEIGHT_CLS_MAP, -) -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import ( - StandardCOLMMWeight, - COLMM_WEIGHT_CLS_MAP, -) - - -class MMWeight: - def __new__(cls, **kwargs): - """ - weight_names, - data_type, - bias_names, - quant_cfg, - layer_num, - name, - tp_rank, - tp_world_size, - ... 
- 该类主要是通过重载 __new__ 为对应的mm权重绑定量化方法,其他参数都是透传。 - """ - - quant_cfg = kwargs.pop("quant_cfg", None) - layer_num_ = kwargs.pop("layer_num", None) - name = kwargs.pop("name", None) - quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name) - # quantized_weight 本身是用来标识权重本身在文件中是否是以量化后的形式存储, - # 现在不再使用该参数,是否量化由后续的加载过程自动识别。 - kwargs["quant_method"] = quant_method - mmcls = cls._get_mmcls(quant_method) - return mmcls(**kwargs) - - @classmethod - def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod: - if quant_cfg is None: - return None, False - quant_method: QuantizationMethod = quant_cfg.get_quant_method(layer_num_, name) - if quant_method is None: - return None, False - quant_method.hf_quantization_config = quant_cfg.hf_quantization_config - quantized_weight = quant_cfg.quantized_weight - return quant_method, quantized_weight - - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod) -> Type[Union[MMWeightTpl, BMMWeightTpl]]: - raise NotImplementedError("Subclasses must implement _get_mmcls method") - - -class ROWMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return StandardROWMMWeight - - return ROWMM_WEIGHT_CLS_MAP.get( - quant_method.method_name, - StandardROWMMWeight, - ) - - -class ROWBMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return UnquantizedROWBMMWeight - else: - # TODO: Implement more quantization weight - raise NotImplementedError("ROWBMMWeight is not implemented") - - -class COLMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return StandardCOLMMWeight - return COLMM_WEIGHT_CLS_MAP.get( - quant_method.method_name, - StandardCOLMMWeight, - ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index e3ef5b0ea..e2830ab61 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -132,3 +132,21 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None): def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias / self.tp_world_size_ + + +def get_row_slice_mixin(quant_method_name: str, tp_rank: int = None, tp_world_size: int = None) -> SliceMixinTpl: + if quant_method_name.startswith("awq"): + return AwqQuantizedRowSliceMixin(tp_rank, tp_world_size) + elif quant_method_name == "none": + return RowSliceMixin(tp_rank, tp_world_size) + else: + return QuantizedRowSliceMixin(tp_rank, tp_world_size) + + +def get_col_slice_mixin(quant_method_name: str, tp_rank: int = None, tp_world_size: int = None) -> SliceMixinTpl: + if quant_method_name.startswith("awq"): + return AwqQuantizedColSliceMixin(tp_rank, tp_world_size) + elif quant_method_name == "none": + return ColSliceMixin(tp_rank, tp_world_size) + else: + return QuantizedColSliceMixin(tp_rank, tp_world_size) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index e603032ec..014cf2ec2 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -5,9 +5,10 @@ from 
dataclasses import dataclass from typing import Optional, Tuple, List, Dict, Union, Type from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from lightllm.common.quantization.quantize_method import QuantizationMethod +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.no_quant import NoQuantization from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.log_utils import init_logger from .mm_slicer import SliceMixinTpl @@ -15,53 +16,11 @@ logger = init_logger(__name__) -@dataclass -class MMWeightPack: - weight: Optional[torch.Tensor] = None - bias: Optional[torch.Tensor] = None - weight_scale: Optional[torch.Tensor] = None - weight_zero_point: Optional[torch.Tensor] = None - - has_bias: bool = False - has_weight_scale: bool = False - has_weight_zero_point: bool = False - - def is_ready(self) -> bool: - return ( - self.weight is not None - and (not self.has_bias or (self.has_bias and self.bias is not None)) - and (not self.has_weight_scale or (self.has_weight_scale and self.weight_scale is not None)) - and (not self.has_weight_zero_point or (self.has_weight_zero_point and self.weight_zero_point is not None)) - ) - - def ready_for_fused_merge(self) -> bool: - """ - 判断权重是否满足可以和其他权重进行融合cat的条件,因为可能权重是量化和非量化后的权重,所以复杂一些。 - """ - weight_ready = self.weight is not None and self.weight.dtype in [ - torch.bfloat16, - torch.float16, - torch.float32, - torch.float64, - ] - bias_ready = (self.has_bias and self.bias is not None) or (not self.has_bias) - if weight_ready and bias_ready: - return True - else: - return self.is_ready() - - def is_load_finished(self): - return ( - (self.is_ready() and self.weight.is_cuda) - and ((self.has_bias and self.bias.is_cuda) or (not self.has_bias)) - and ((self.has_weight_scale and self.weight_scale.is_cuda) or (not self.has_weight_scale)) - and ((self.has_weight_zero_point and self.weight_zero_point.is_cuda) or (not self.has_weight_zero_point)) - ) - - class MMWeightTpl(BaseWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], bias_names: Optional[Union[str, List[str]]], data_type: torch.dtype, @@ -72,6 +31,14 @@ def __init__( super().__init__(tp_rank, tp_world_size, data_type) self.lock = threading.Lock() + self.in_dim = in_dim + if isinstance(out_dims, int): + out_dims = [out_dims] + self.out_dims = out_dims + self.cusum_out_dims = [0] + for out_dim in out_dims[:-1]: + self.cusum_out_dims.append(self.cusum_out_dims[-1] + out_dim) + if isinstance(weight_names, str): weight_names = [weight_names] if isinstance(bias_names, str): @@ -82,60 +49,29 @@ def __init__( if bias_names[0] is None: bias_names = None - if quant_method is not None: - has_weight_scale = quant_method.has_weight_scale - has_weight_zero_point = quant_method.has_weight_zero_point - else: - has_weight_scale = False - has_weight_zero_point = False - # 同时存在 weight_names 和 quanted_weight_names 是为了兼容在线和离线两种加载方案 self.weight_names = weight_names - self.bias_names = bias_names - has_bias = self.bias_names is not None - - self.gen_weight_quant_param_names(quant_method=quant_method) - self.quant_method = quant_method - self.sub_child_mm_params: List[MMWeightPack] = [ - MMWeightPack( - has_bias=has_bias, - has_weight_scale=has_weight_scale, - 
has_weight_zero_point=has_weight_zero_point, - ) - for _ in range(len(weight_names)) - ] - self.mm_param: MMWeightPack = MMWeightPack( - has_bias=has_bias, - has_weight_scale=has_weight_scale, - has_weight_zero_point=has_weight_zero_point, - ) + self.quant_method: QuantizationMethod = NoQuantization() if quant_method is None else quant_method self.param_slicer: SliceMixinTpl = None + self._create_weight() + self.gen_weight_quant_param_names(quant_method=quant_method) - self.weight_fused_dim = 0 - self.bias_fused_dim = 0 - self.weight_scale_and_zero_point_fused_dim = 0 - - self.load_finished: bool = False + def _create_weight(self): + self.bias = None + if self.bias_names is not None: + self.bias = torch.empty(self.cusum_out_dims[-1], dtype=self.data_type_).cuda(get_current_device_id()) + self.mm_param: WeightPack = self.quant_method.create_weight( + in_dim=self.in_dim, out_dim=sum(self.out_dims), dtype=self.data_type_, device_id=get_current_device_id() + ) + return def mm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: - if self.quant_method is not None: - return self.quant_method.apply( - input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger - ) - if out is None: - shape = (input_tensor.shape[0], self.mm_param.weight.shape[1]) - dtype = input_tensor.dtype - device = input_tensor.device - if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) - else: - out = torch.empty(shape, dtype=dtype, device=device) - if self.mm_param.bias is None: - return torch.mm(input_tensor, self.mm_param.weight, out=out) - return torch.addmm(self.mm_param.bias, input_tensor, self.mm_param.weight, out=out) + return self.quant_method.apply( + input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger, bias=self.bias + ) def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod]): if quant_method is None: @@ -176,8 +112,6 @@ def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod return def load_hf_weights(self, weights): - if self.mm_param.is_load_finished(): - return for sub_child_index, param_name in enumerate(self.weight_names): self._load_weight(param_name=param_name, weights=weights, sub_child_index=sub_child_index) @@ -196,51 +130,8 @@ def load_hf_weights(self, weights): for sub_child_index, param_name in enumerate(self.weight_zero_point_names): self._load_weight_zero_point(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - with self.lock: - # 如果需要fused的请求,全部ok了以后进行merge操作。, all([]) 竟然返回是True, 需要len(self.sub_child_mm_params) > 0 的额外判断。 - if len(self.sub_child_mm_params) > 0 and all(e.ready_for_fused_merge() for e in self.sub_child_mm_params): - self._fuse_weights() - self.sub_child_mm_params.clear() - - # 在线量化操作 - if ( - self.quant_method is not None - and self.mm_param.weight is not None - and self.quant_method.weight_need_quanted(self.mm_param.weight) - and self.load_finished is False - ): - logger.info(f"online quant weight names: {self.weight_names}") - quantized_weight, weight_scale, weight_zero_point = self.quant_method.quantize( - self.mm_param.weight.cuda(get_current_device_id()) - ) - self.mm_param.weight = quantized_weight - self.mm_param.weight_scale = weight_scale - self.mm_param.weight_zero_point = weight_zero_point - - # repack 操作 - if ( - self.quant_method is not None - and self.mm_param.is_ready() - and 
self.quant_method.params_need_repack() - and self.load_finished is False - ): - ( - self.mm_param.weight, - self.mm_param.weight_scale, - self.mm_param.weight_zero_point, - ) = self.quant_method.params_repack( - weight=self.mm_param.weight, - weight_scale=self.mm_param.weight_scale, - weight_zero_point=self.mm_param.weight_zero_point, - dtype_type=self.data_type_, - ) - - if self.mm_param.is_ready() and self.load_finished is False: - self._to_gpu_device() - self.load_finished = True - def verify_load(self) -> bool: - return self.mm_param.is_ready() + return True # 执行顺序 def _load_weight( @@ -248,7 +139,11 @@ def _load_weight( ) -> None: if param_name in weights: weight = self.param_slicer._slice_weight(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight = weight + start_idx = self.cusum_out_dims[sub_child_index] + if self.quant_method.weight_need_quanted(weight): + self.quant_method.quantize(weight, self.mm_param, offset=start_idx) + else: + self.quant_method.load_weight(weight, self.mm_param, start_idx) return def _load_bias( @@ -256,7 +151,9 @@ def _load_bias( ) -> None: if param_name in weights: bias = self.param_slicer._slice_bias(weights[param_name]) - self.sub_child_mm_params[sub_child_index].bias = bias + start_idx = self.cusum_out_dims[sub_child_index] + end_idx = start_idx + bias.shape[0] + self.mm_param.bias[start_idx:end_idx].copy_(bias) return def _load_weight_scale( @@ -264,7 +161,8 @@ def _load_weight_scale( ) -> None: if param_name in weights: weight_scale = self.param_slicer._slice_weight_scale(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight_scale = weight_scale + start_idx = self.cusum_out_dims[sub_child_index] + self.quant_method.load_weight_scale(weight_scale, self.mm_param, start_idx) return def _load_weight_zero_point( @@ -272,88 +170,8 @@ def _load_weight_zero_point( ) -> None: if param_name in weights: weight_zero_point = self.param_slicer._slice_weight_zero_point(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight_zero_point = weight_zero_point - return - - # weight merge - def _fuse_weights(self) -> None: - need_merge = len(self.sub_child_mm_params) > 1 - if self.mm_param.weight is None and all(p.weight is not None for p in self.sub_child_mm_params): - if need_merge: - weight = torch.cat([p.weight for p in self.sub_child_mm_params], dim=self.weight_fused_dim) - else: - weight = self.sub_child_mm_params[0].weight - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight = None - - self.mm_param.weight = weight - - if ( - self.mm_param.has_bias - and self.mm_param.bias is None - and all(p.bias is not None for p in self.sub_child_mm_params) - ): - if need_merge: - bias = torch.cat([p.bias for p in self.sub_child_mm_params], dim=self.bias_fused_dim) - else: - bias = self.sub_child_mm_params[0].bias - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.bias = None - - self.mm_param.bias = bias - - if self.mm_param.weight_scale is None and all(p.weight_scale is not None for p in self.sub_child_mm_params): - if need_merge: - weight_scale = torch.cat( - [p.weight_scale for p in self.sub_child_mm_params], dim=self.weight_scale_and_zero_point_fused_dim - ) - else: - weight_scale = self.sub_child_mm_params[0].weight_scale - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight_scale = None - - self.mm_param.weight_scale = weight_scale - - if self.mm_param.weight_zero_point is None and all( - p.weight_zero_point is not None for p in self.sub_child_mm_params - ): - if need_merge: - 
weight_zero_point = torch.cat( - [p.weight_zero_point for p in self.sub_child_mm_params], - dim=self.weight_scale_and_zero_point_fused_dim, - ) - else: - weight_zero_point = self.sub_child_mm_params[0].weight_zero_point - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight_zero_point = None - - self.mm_param.weight_zero_point = weight_zero_point - return - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - if self.quant_method is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - else: - # 让 k dim 更连续,大多数split k 算法的算子可能能更快 - self.mm_param.weight = ( - self.mm_param.weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1) - ) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) + start_idx = self.cusum_out_dims[sub_child_index] + self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param, start_idx) return @@ -376,90 +194,6 @@ def bmm( out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) else: out = torch.empty(shape, dtype=dtype, device=device) - if self.mm_param.bias is None: + if self.bias is None: return torch.bmm(input_tensor, fpweight, out=out) - return torch.addbmm(self.mm_param.bias, input_tensor, fpweight, out=out) - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - if self.quant_method is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - else: - # bmm 不需要 transpose 操作 - self.mm_param.weight = self.mm_param.weight.to(self.data_type_).cuda(get_current_device_id()) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return - - -class DeepGemmFP8W8A8B128MMWeight(MMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - bias_names=bias_names, - data_type=data_type, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()).transpose(0, 1) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()).transpose(0, 1) - - assert self.mm_param.has_weight_zero_point is False - - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return - - -class AWQMMWeightTpl(MMWeightTpl): - def 
__init__( - self, - weight_names: Union[str, List[str]], - bias_names: Optional[Union[str, List[str]]] = None, - data_type: torch.dtype = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - bias_names=bias_names, - data_type=data_type, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - self.weight_fused_dim = 1 - self.bias_fused_dim = 0 - self.weight_scale_and_zero_point_fused_dim = 1 - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.to(self.data_type_).cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return + return torch.addbmm(self.bias, input_tensor, fpweight, out=out) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index 0eebdc74d..e53d643ce 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -1,20 +1,20 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( MMWeightTpl, - DeepGemmFP8W8A8B128MMWeight, - AWQMMWeightTpl, BMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union -from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, AwqQuantizedRowSliceMixin +from .mm_slicer import get_row_slice_mixin -class StandardROWMMWeight(MMWeightTpl): +class ROWMMWeight(MMWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -23,6 +23,8 @@ def __init__( tp_world_size: int = None, ) -> None: super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, bias_names=bias_names, data_type=data_type, @@ -30,32 +32,12 @@ def __init__( tp_rank=tp_rank, tp_world_size=tp_world_size, ) - self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class DeepGemmFP8W8A8B128ROWMMWeight(DeepGemmFP8W8A8B128MMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + self.param_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=tp_rank, tp_world_size=tp_world_size ) - self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - return -class 
UnquantizedROWBMMWeight(BMMWeightTpl): +class ROWBMMWeight(BMMWeightTpl): def __init__( self, weight_names: Union[str, List[str]], @@ -73,53 +55,5 @@ def __init__( tp_rank=tp_rank, tp_world_size=tp_world_size, ) - self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQROWMMWeight(AWQMMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQMARLINROWMMWeight(AWQROWMMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - -ROWMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128ROWMMWeight, - "awq": AWQROWMMWeight, - "awq_marlin": AWQMARLINROWMMWeight, -} + # bmm 不支持量化运算操作 + self.param_slicer = get_row_slice_mixin(quant_method_name="none", tp_rank=tp_rank, tp_world_size=tp_world_size) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 7ec672ab8..b92ec24cb 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -4,49 +4,59 @@ class NormWeight(BaseWeightTpl): - def __init__(self, weight_name, data_type, bias_name=None): + def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): super().__init__() + self.norm_dim = norm_dim self.weight_name = weight_name self.bias_name = bias_name self.data_type_ = data_type self.weight = None self.bias = None + self.is_weight_ready = False + self.is_bias_ready = False + self._create_weight() + + def _create_weight(self): + device = f"cuda:{get_current_device_id()}" + self.weight = torch.empty(self.norm_dim, dtype=self.data_type_, device=device) + self.bias = ( + torch.empty(self.norm_dim, dtype=self.data_type_, device=device) if self.bias_name is not None else None + ) def load_hf_weights(self, weights): if self.weight_name in weights: - self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) + self.weight.copy_(weights[self.weight_name]) + self.is_weight_ready = True if self.bias_name in weights: - self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) + self.bias.copy_(weights[self.bias_name]) + self.is_bias_ready = True def verify_load(self): - load_ok = True - # Verify weight. The weight must be not None. - load_ok = load_ok and self.weight is not None - # Verify bias. If bias_name is set, it must be not None. 
- if self.bias_name is not None: - load_ok = load_ok and self.bias is not None - return load_ok + return self.is_weight_ready and (self.bias_name is None or self.is_bias_ready) class GEMMANormWeight(NormWeight): - def __init__(self, weight_name, data_type, bias_name=None): - super().__init__(weight_name, data_type, bias_name) + def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): + super().__init__(norm_dim, weight_name, data_type, bias_name) def load_hf_weights(self, weights): + # TODO: 这里直接 +1 会不会导致精度问题? 计算时要求 (1.0 + weight.float()) ? if self.weight_name in weights: - self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id()) + self.weight.copy_((weights[self.weight_name] + 1).to(self.data_type_)) + self.is_weight_ready = True class TpNormWeight(NormWeight): - def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): - super().__init__(weight_name, data_type, bias_name) - self.split_n_embed = split_n_embed + def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): + super().__init__(norm_dim, weight_name, data_type, bias_name) def load_hf_weights(self, weights): - start = self.split_n_embed * self.tp_rank_ - end = self.split_n_embed * (self.tp_rank_ + 1) + start = self.norm_dim * self.tp_rank_ + end = self.norm_dim * (self.tp_rank_ + 1) if self.weight_name in weights: - self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + self.weight.copy_(weights[self.weight_name][start:end].to(self.data_type_)) + self.is_weight_ready = True if self.bias_name in weights: - self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + self.bias.copy_(weights[self.bias_name][start:end].to(self.data_type_)) + self.is_bias_ready = True diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 97bc76237..1889ceb39 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -4,6 +4,7 @@ from .base_layer_weight import BaseLayerWeight from .meta_weights import BaseWeight, MMWeightTpl from lightllm.utils.log_utils import init_logger +from lightllm.common.quantization import Quantcfg logger = init_logger(__name__) @@ -15,7 +16,7 @@ def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): self.data_type_ = data_type self.network_config_ = network_config self.mode = mode - self.quant_cfg = quant_cfg + self.quant_cfg: Quantcfg = quant_cfg self._parse_config() self._init_weight_names() self._init_weight() @@ -41,3 +42,6 @@ def load_hf_weights(self, weights): attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) + + def get_quant_method(self, name): + return self.quant_cfg.get_quant_method(self.layer_num_, name) diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 26f59258c..ecf2e6d42 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -6,6 +6,7 @@ from .triton_quant.triton_quant import * from .deepgemm_quant import * from .awq_quant import * +from .no_quant import * from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -78,4 +79,6 @@ def get_quant_type(self, layer_num, name): def get_quant_method(self, layer_num, name): quant_type = 
self.get_quant_type(layer_num, name) - return QUANTMETHODS.get(quant_type) + quant_method = QUANTMETHODS.get(quant_type) + quant_method.hf_quantization_config = self.hf_quantization_config + return quant_method diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq_quant.py index 8c04cdcea..d523cce75 100644 --- a/lightllm/common/quantization/awq_quant.py +++ b/lightllm/common/quantization/awq_quant.py @@ -9,8 +9,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from lightllm.utils.dist_utils import get_current_device_id -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from .quantize_method import WeightPack if HAS_VLLM: awq_dequantize = vllm_ops.awq_dequantize @@ -39,16 +38,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("AWQ online quantization is not supported yet.") @@ -72,21 +72,21 @@ def __init__(self): def method_name(self): return "awq" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point - bias = weight_pack.bias NEED_DEQUANT_WEIGHT = input_tensor.shape[:-1].numel() >= 256 if NEED_DEQUANT_WEIGHT: @@ -99,6 +99,33 @@ def apply( out.add_(bias) return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + start_idx = start_idx // self.pack_factor + weight_pack.weight[:, start_idx : start_idx + weight.shape[1]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[:, start_idx : start_idx + weight_scale.shape[1]].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + start_idx = start_idx // 
self.pack_factor + end_idx = start_idx + weight_zero_point.shape[1] + weight_pack.weight_zero_point[:, start_idx:end_idx].copy_(weight_zero_point) + return + @QUANTMETHODS.register("awq_marlin") class AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod): @@ -115,20 +142,15 @@ def __init__(self): self.vllm_quant_type = TYPE_MAP[self.nbits] self.has_weight_scale = True self.has_weight_zero_point = True + self.tile_size = 16 @property def method_name(self): return "awq_marlin" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, offset: int = 0) -> WeightPack: raise NotImplementedError("AWQ online quantization is not supported yet.") - def params_need_repack(self) -> bool: - """ - 用于说明是否需要对量化后的权重进行repack操作,目前只有awq支持 - """ - return True - def params_repack( self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -144,47 +166,18 @@ def params_repack( ) return weight, weight_scale, weight_zero_point - def _process_weight_after_loading(self, weight: torch.Tensor) -> torch.Tensor: - assert self.hf_quantization_config is not None, "hf_quantization_config is not set" - self.k = weight.shape[0] - self.n = weight.shape[1] * self.pack_factor - return vllm_ops.awq_marlin_repack( - weight, - size_k=weight.shape[0], - size_n=weight.shape[1] * self.pack_factor, - num_bits=self.hf_quantization_config["bits"], - ) - - def _process_weight_scale_after_loading(self, weight_scale: torch.Tensor) -> torch.Tensor: - assert self.hf_quantization_config is not None, "hf_quantization_config is not set" - group_size = self.hf_quantization_config["group_size"] - return marlin_permute_scales( - weight_scale, - size_k=weight_scale.shape[0] * group_size, - size_n=weight_scale.shape[1], - group_size=self.hf_quantization_config["group_size"], - ) - - def _process_weight_zero_point_after_loading(self, weight_zero_point: torch.Tensor) -> torch.Tensor: - return awq_to_marlin_zero_points( - weight_zero_point, - size_k=weight_zero_point.shape[0], - size_n=weight_zero_point.shape[1] * self.pack_factor, - num_bits=self.hf_quantization_config["bits"], - ) - def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point - bias = weight_pack.bias reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1]) use_atomic_add = should_use_atomic_add_reduce( @@ -219,6 +212,62 @@ def apply( out.add_(bias) return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + self.n = out_dim + self.k = in_dim + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty( + expert_prefix + (in_dim // self.tile_size, out_dim * self.tile_size // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale, 
weight_zero_point=weight_zero_point) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + device_id = get_current_device_id() + repack_weight = vllm_ops.awq_marlin_repack( + weight.cuda(device_id), + size_k=weight.shape[0], + size_n=weight.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + start_idx = start_idx // self.pack_factor * self.tile_size + weight_pack.weight[:, start_idx : start_idx + repack_weight.shape[1]].copy_(repack_weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + group_size = self.hf_quantization_config["group_size"] + device_id = get_current_device_id() + repack_weight_scale = marlin_permute_scales( + weight_scale.cuda(device_id), + size_k=weight_scale.shape[0] * group_size, + size_n=weight_scale.shape[1], + group_size=self.hf_quantization_config["group_size"], + ) + weight_pack.weight_scale[:, start_idx : start_idx + repack_weight_scale.shape[1]].copy_(repack_weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + device_id = get_current_device_id() + repack_weight_zero_point = awq_to_marlin_zero_points( + weight_zero_point.cuda(device_id), + size_k=weight_zero_point.shape[0], + size_n=weight_zero_point.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + start_idx = start_idx // self.pack_factor + weight_pack.weight_zero_point[:, start_idx : start_idx + repack_weight_zero_point.shape[1]].copy_( + repack_weight_zero_point + ) + return + # adapted from # https://github.com/vllm-project/vllm/blob/aef368aa08572505b820db01da82e2fbb3d43a72/vllm/model_executor/layers/quantization/awq_marlin.py#L211-L212 diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index f56630780..86dd9b572 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -1,5 +1,6 @@ import os import torch +from torch.types import Device from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS import torch.nn.functional as F @@ -9,8 +10,8 @@ ) from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from .quantize_method import WeightPack + try: HAS_DEEPGEMM = True import deep_gemm @@ -26,17 +27,17 @@ def __init__(self): self.cache_manager = g_cache_manager assert HAS_DEEPGEMM, "deepgemm is not installed, you can't use quant api of it" - def quantize(self, weight: torch.Tensor): - """ """ - pass + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -60,25 +61,30 @@ def __init__(self): def method_name(self): return "deepgemm-fp8w8a8-b128" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: 
WeightPack, offset: int = 0): from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant - return weight_quant(weight, self.block_size) + device = output.weight.device + weight, scale = weight_quant(weight.cuda(device), self.block_size) + output.weight[offset : offset + weight.shape[0], :].copy_(weight) + output.weight_scale[offset // self.block_size : offset + weight.shape[0] // self.block_size].copy_(scale) + return def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: "WeightPack", out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale input_scale = None alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty m, k = input_tensor.shape - n = qweight.shape[1] + n = qweight.shape[0] if input_scale is None: qinput_tensor, input_scale = per_token_group_quant_fp8( input_tensor, @@ -91,9 +97,35 @@ def apply( if out is None: out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight.t(), weight_scale.t()), out) + _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight, weight_scale), out) return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[ + start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size + ].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_zero_point[ + start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size + ].copy_(weight_zero_point) + return + def _deepgemm_fp8_nt(a_tuple, b_tuple, out): if HAS_DEEPGEMM: diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py new file mode 100644 index 000000000..f342607c1 --- /dev/null +++ b/lightllm/common/quantization/no_quant.py @@ -0,0 +1,52 @@ +from .quantize_method import QuantizationMethod, WeightPack +from .registry import QUANTMETHODS +import torch +from typing import Optional + + +@QUANTMETHODS.register("none") +class NoQuantization(QuantizationMethod): + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + weight = weight_pack.weight.t() + if out is None: + shape = 
(input_tensor.shape[0], weight.shape[1]) + dtype = input_tensor.dtype + device = input_tensor.device + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) + else: + out = torch.empty(shape, dtype=dtype, device=device) + if bias is None: + return torch.mm(input_tensor, weight, out=out) + return torch.addmm(bias, input_tensor, weight, out=out) + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=dtype).cuda(device_id) + return WeightPack(weight=weight, weight_scale=None, weight_zero_point=None) + + def weight_need_quanted(self, weight: torch.Tensor) -> bool: + return False + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + return + + @property + def method_name(self): + return "none" + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0], :].copy_(weight) + return diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 9b629bcaf..77e59465e 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,38 +1,58 @@ import torch from abc import ABC, abstractmethod +from dataclasses import dataclass from lightllm.utils.dist_utils import get_current_device_id -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, Tuple -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + +@dataclass +class WeightPack: + weight: Optional[torch.Tensor] = None + weight_scale: Optional[torch.Tensor] = None + weight_zero_point: Optional[torch.Tensor] = None + + def get_expert(self, expert_idx: int): + assert self.weight.ndim == 3, f"weight must be a 3D tensor, but got {self.weight.ndim}" + weight = self.weight[expert_idx] + weight_scale = self.weight_scale[expert_idx] if self.weight_scale is not None else None + weight_zero_point = self.weight_zero_point[expert_idx] if self.weight_zero_point is not None else None + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) class QuantizationMethod(ABC): def __init__(self): super().__init__() self.device_id_ = get_current_device_id() - self.weight_suffix = None + self.weight_suffix = "weight" self.weight_scale_suffix = None self.weight_zero_point_suffix = None self.act_scale_suffix = None self.has_weight_scale: bool = None self.has_weight_zero_point: bool = None + self.group_size: int = -1 # -1表示不分组即per-channel量化,其他表示分组大小 + self.pack_factor: int = 1 + # 一些量化模式需要用到的额外量化参数,如awq量化 self.hf_quantization_config = None @abstractmethod - def quantize(self, weights: torch.Tensor): + def quantize( + self, + weight: torch.Tensor, + output: WeightPack, + offset: int = 0, + ) -> None: pass @abstractmethod def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", - bias: Optional[torch.Tensor] = None, + weight_pack: "WeightPack", out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass @@ -41,20 +61,26 @@ def apply( def method_name(self): pass + def create_weight( + self, out_dim: 
int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + pass + def weight_need_quanted(self, weight: torch.Tensor) -> bool: # 判断一个 weight 是否需要进行量化操作。 return weight.dtype in [torch.bfloat16, torch.float16, torch.float32, torch.float64] - def params_need_repack(self) -> bool: - """ - 用于说明是否需要对量化后的权重进行repack操作,目前只有awq支持 - """ - return False - - def params_repack( - self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - 一些量化方法在将参数完成量化后,为了加速性能,还需要将参数进行重拍,使算子性能达到最优,如awq方法。 - """ - return weight, weight_scale, weight_zero_point + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + raise NotImplementedError( + f"quantization method {self.method_name} is not supported to load offline quantized weight" + ) + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + raise NotImplementedError( + f"quantization method {self.method_name} is not supported to load offline quantized weight scale" + ) + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + raise NotImplementedError( + f"quantization method {self.method_name} is not supported to load offline quantized weight zero point" + ) diff --git a/lightllm/common/quantization/registry.py b/lightllm/common/quantization/registry.py index 674a22b60..e9b407398 100644 --- a/lightllm/common/quantization/registry.py +++ b/lightllm/common/quantization/registry.py @@ -1,5 +1,4 @@ from .quantize_method import QuantizationMethod -from typing import Type class QuantMethodFactory: @@ -17,9 +16,7 @@ def decorator(cls): return decorator - def get(self, key, *args, **kwargs) -> Type[QuantizationMethod]: - if key == "none": - return None + def get(self, key, *args, **kwargs) -> "QuantizationMethod": quant_method_class = self._quant_methods.get(key) if not quant_method_class: raise ValueError(f"QuantMethod '{key}' not supported.") diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index ba4115b1d..d1db65b35 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -5,8 +5,7 @@ import torch.nn.functional as F from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from .quantize_method import WeightPack try: HAS_TORCH_AO = True @@ -34,17 +33,17 @@ def __init__(self): assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4" self.quant_func = None - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, offset: int = 0) -> WeightPack: """ """ dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_)) quantize_(dummy_linear, self.quant_func) - return dummy_linear.weight, None, None + return WeightPack(weight=dummy_linear.weight, weight_scale=None, weight_zero_point=None) def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py 
b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py index 11c1897d7..3881cfe4b 100644 --- a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py +++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py @@ -55,4 +55,4 @@ def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, return y_quant, s_scales else: y_quant, s_scales = mm_weight_quant(x, block_size) - return y_quant.t(), s_scales.t() + return y_quant, s_scales diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py new file mode 100644 index 000000000..7c76e82c9 --- /dev/null +++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py @@ -0,0 +1,471 @@ +import torch +import triton +import triton.language as tl + +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Any, Dict, List, Optional, Tuple +from triton import Config +from lightllm.common.triton_utils.autotuner import autotune +from lightllm.utils.device_utils import triton_support_tensor_descriptor, is_5090_gpu + + +class Fp8ScaledMMKernelConfig(KernelConfigs): + kernel_name: str = "fp8_scaled_mm_per_token" + + @classmethod + @lru_cache(maxsize=200) + def try_to_get_best_config( + cls, + M: int, + N: int, + K: int, + out_dtype: str, + ) -> dict: + key_params = { + "N": N, + "K": K, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + finded_config = cls.get_the_config(key_params) + + if finded_config: + # find by M + config: dict = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))] + return config + else: + config = { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 64, + "GROUP_M": 8, + "num_warps": 4, + "num_stages": 3, + } + return config + + @classmethod + def save_config(cls, N: int, K: int, out_dtype: str, config_json: Dict[int, Dict[int, Dict]]): + + key_params = { + "N": N, + "K": K, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + return cls.store_config(key_params, config_json) + + +@triton.jit +def grouped_launch(pid, m_block_num, n_block_num, group_m: tl.constexpr): + + num_pid_in_group = group_m * n_block_num + group_id = pid // num_pid_in_group + first_pid_m = group_id * group_m + group_size_m = tl.minimum(m_block_num - first_pid_m, group_m) + in_group_index = pid % num_pid_in_group + + # Swizzle pattern: zigzag traversal + back_mark = (in_group_index // group_size_m) % 2 + back_mark1 = -1 * (2 * back_mark - 1) + pid_m = first_pid_m + back_mark * (group_size_m - 1) + back_mark1 * (in_group_index % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + return pid_m, pid_n + + +@triton.jit +def _scaled_mm_per_token( + A, + A_desc: "tl.core.tensor_descriptor", + B, + B_desc: "tl.core.tensor_descriptor", + out, + out_desc: "tl.core.tensor_descriptor", + Ascale, + Bscale, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + USE_TMA: tl.constexpr, + B_IS_TRANS: tl.constexpr, + NEED_N_MASK: tl.constexpr, + NEED_K_MASK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + m_block_num = tl.cdiv(M, BLOCK_M) + n_block_num = tl.cdiv(N, BLOCK_N) + pid_m, pid_n = grouped_launch(pid, m_block_num, n_block_num, GROUP_M) + + 
start_m = pid_m * BLOCK_M + start_n = pid_n * BLOCK_N + + offs_am = start_m + tl.arange(0, BLOCK_M) + offs_bn = start_n + tl.arange(0, BLOCK_N) + + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N) + + offs_k = tl.arange(0, BLOCK_K) + + if not USE_TMA: + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + Ascale_ptrs = Ascale + offs_am + Bscale_ptrs = Bscale + offs_bn + a_s = tl.load(Ascale_ptrs) + b_s = tl.load(Bscale_ptrs) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_K)): + if USE_TMA: + a = A_desc.load([start_m, k * BLOCK_K]) + if not B_IS_TRANS: + b = B_desc.load([k * BLOCK_K, start_n]) + else: + b = B_desc.load([start_n, k * BLOCK_K]).T + elif NEED_K_MASK: + k_remaining = K - k * BLOCK_K + a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) + else: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + acc = tl.dot(a, b, acc) + if not USE_TMA: + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + acc = acc * a_s[:, None] * b_s[None, :] + + acc = acc.to(out.dtype.element_ty) + + if not USE_TMA: + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = out + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + if NEED_N_MASK: + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + else: + mask = offs_cm[:, None] < M + tl.store(c_ptrs, acc, mask=mask) + else: + out_desc.store([start_m, start_n], acc) + + +def get_test_configs(): + fp8_gemm_configs = [] + + for BLOCK_M in [8, 16, 32, 64]: + for BLOCK_N in [64, 128, 256]: + for BLOCK_K in [32, 64, 128, 256]: + if BLOCK_K * BLOCK_M * BLOCK_N >= 256 * 256 * 128: + continue + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4, 5, 6]: + config = { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "BLOCK_K": BLOCK_K, + "GROUP_M": 8, + "num_stages": num_stages, + "num_warps": num_warps, + } + fp8_gemm_configs.append(config) + + return fp8_gemm_configs + + +def _get_static_key(A, B, out_dtype): + M, K = A.shape + _, N = B.shape + return { + "N": N, + "K": K, + "out_dtype": str(out_dtype), + } + + +@autotune( + kernel_name="fp8_scaled_mm_per_token:v3", + configs_gen_func=get_test_configs, + static_key_func=_get_static_key, + run_key_func=lambda A: A.shape[0], + mutates_args=["out"], +) +def fp8_scaled_mm_per_token( + A: torch.Tensor, + B: torch.Tensor, + Ascale: torch.Tensor, + Bscale: torch.Tensor, + out_dtype: torch.dtype, + out: torch.Tensor, + run_config=None, +) -> torch.Tensor: + """w8a8fp8 per-token quantization mm. + + Args: + A: Matrix A with shape of [M, K]. + B: Matrix B with shape of [K, N]. + Ascale: per-token Quantization scale for A: [M] or [M, 1]. + Bscale: per-channel Quantization scale for B: [N] or [1, N]. + out_dtype: The data type of out. + out: The output matrix with the shape of [M, N]. + Returns: + torch.Tensor: out. 
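+
+    Example:
+        A minimal, illustrative sketch; the shapes, dtypes, and random scale
+        values below are assumptions chosen for demonstration, not requirements
+        of the kernel beyond what the Args above describe.
+
+        >>> M, K, N = 16, 512, 1024
+        >>> A = torch.randn((M, K), dtype=torch.bfloat16).cuda().to(torch.float8_e4m3fn)
+        >>> B = torch.randn((N, K), dtype=torch.bfloat16).cuda().to(torch.float8_e4m3fn).t()  # [K, N]
+        >>> Ascale = torch.rand((M, 1)).cuda()   # per-token scales for A
+        >>> Bscale = torch.rand((1, N)).cuda()   # per-channel scales for B
+        >>> out = torch.empty((M, N), dtype=torch.bfloat16).cuda()
+        >>> fp8_scaled_mm_per_token(A, B, Ascale, Bscale, torch.bfloat16, out)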
+ """ + assert A.is_contiguous() + B_is_trans = not B.is_contiguous() and B.stride(0) == 1 + + M, K = A.shape + _, N = B.shape + if not run_config: + run_config = Fp8ScaledMMKernelConfig.try_to_get_best_config(M=M, N=N, K=K, out_dtype=out_dtype) + NEED_N_MASK = N % run_config["BLOCK_N"] != 0 + NEED_K_MASK = K % run_config["BLOCK_K"] != 0 + grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),) + + BLOCK_M = run_config["BLOCK_M"] + BLOCK_K = run_config["BLOCK_K"] + BLOCK_N = run_config["BLOCK_N"] + + # use tma + support_tma = triton_support_tensor_descriptor() + # 5090 上,小shape开启tma性能不是很好。 + support_tma = support_tma and (not is_5090_gpu()) + if support_tma: + stride = A.stride(-2) + if (stride * A.dtype.itemsize) % 16 != 0: + support_tma = False + _B = B if not B_is_trans else B.transpose(0, 1) + stride = _B.stride(-2) + if (stride * _B.dtype.itemsize) % 16 != 0: + support_tma = False + + if support_tma: + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + from triton.tools.tensor_descriptor import TensorDescriptor + + A_desc = TensorDescriptor(A, A.shape, A.stride(), [BLOCK_M, BLOCK_K]) + if B_is_trans: + _B = B.transpose(0, 1) + assert _B.is_contiguous() + B_desc = TensorDescriptor(_B, _B.shape, _B.stride(), [BLOCK_N, BLOCK_K]) + else: + B_desc = TensorDescriptor(B, B.shape, B.stride(), [BLOCK_K, BLOCK_N]) + out_desc = TensorDescriptor(out, out.shape, out.stride(), [BLOCK_M, BLOCK_N]) + else: + A_desc = None + B_desc = None + out_desc = None + + _scaled_mm_per_token[grid]( + A=A, + A_desc=A_desc, + B=B, + B_desc=B_desc, + out=out, + out_desc=out_desc, + Ascale=Ascale, + Bscale=Bscale, + M=M, + N=N, + K=K, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_bk=B.stride(0), + stride_bn=B.stride(1), + stride_cm=out.stride(0), + stride_cn=out.stride(1), + USE_TMA=support_tma, + B_IS_TRANS=B_is_trans, + NEED_N_MASK=NEED_N_MASK, + NEED_K_MASK=NEED_K_MASK, + **run_config, + ) + + return out + + +if __name__ == "__main__": + import time + import os + from lightllm.common.triton_utils.autotuner import Autotuner + import torch.nn.functional as F + + output_dtype = torch.bfloat16 + N, K = 4096, 5120 + + # 测试多个不同的 M 值 + M_list = [1, 2, 4, 8, 16, 32, 48] + + print(f"{'='*80}") + print(f"Starting Autotune for FP8 Scaled MM (N={N}, K={K})") + print(f"M values to test: {M_list}") + print(f"Total configs per M: {len(get_test_configs())}") + print(f"{'='*80}\n") + + # 准备权重矩阵 B(所有测试共享) + B = torch.randn((N, K), dtype=output_dtype).cuda().to(torch.float8_e4m3fn).transpose(0, 1) # [K, N] + Bscale = torch.ones((1, N)).cuda() + + # 准备所有测试数据 + test_data = {} + for M in M_list: + A = torch.randn((M, K), dtype=output_dtype).cuda().to(torch.float8_e4m3fn) + Ascale = torch.randn((M, 1)).cuda() + out = torch.zeros((M, N), dtype=output_dtype).cuda() + test_data[M] = {"A": A, "Ascale": Ascale, "out": out} + + # ============ Phase 0: Correctness Check ============ + print("\n" + "=" * 80) + print("PHASE 0: Verifying Correctness Before Autotune") + print("=" * 80) + + # 选择一个中等大小的 M 进行正确性验证 + M_verify = 16 if 16 in M_list else M_list[len(M_list) // 2] + A_verify = test_data[M_verify]["A"] + Ascale_verify = test_data[M_verify]["Ascale"] + out_verify = test_data[M_verify]["out"] + + print(f"\n[Verification] Testing with M={M_verify}") + + # 计算ground truth + d_A = A_verify.to(output_dtype) * 
Ascale_verify.to(output_dtype) + d_B = B.to(output_dtype) * Bscale.to(output_dtype) + gt_C = d_A.mm(d_B) + + # 运行kernel验证正确性 + fp8_scaled_mm_per_token(A_verify, B, Ascale_verify, Bscale, output_dtype, out_verify) + + # 计算cosine similarity + cosine_sim = F.cosine_similarity(out_verify.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) + print(f"[Verification] Cosine Similarity: {cosine_sim.item():.6f}") + + # 计算max absolute error + max_abs_error = torch.max(torch.abs(out_verify - gt_C)).item() + mean_abs_error = torch.mean(torch.abs(out_verify - gt_C)).item() + print(f"[Verification] Max Absolute Error: {max_abs_error:.6e}") + print(f"[Verification] Mean Absolute Error: {mean_abs_error:.6e}") + + # 验证阈值 + if cosine_sim.item() < 0.99: + raise RuntimeError(f"Correctness check failed! Cosine similarity {cosine_sim.item():.6f} < 0.99") + + print("[Verification] ✅ Correctness check passed!") + print("=" * 80) + + # ============ Phase 1: Autotune ============ + print("\n" + "=" * 80) + print("PHASE 1: Running Autotune") + print("=" * 80) + Autotuner.start_autotune_warmup() + + for M in M_list: + print(f"\n[M={M}] Running autotune...") + A = test_data[M]["A"] + Ascale = test_data[M]["Ascale"] + out = test_data[M]["out"] + fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + print(f"[M={M}] Autotune completed!") + + Autotuner.end_autotune_warmup() + print("\n" + "=" * 80) + print("All autotune completed! Now starting benchmarks...") + print("=" * 80) + + # ============ Phase 2: Benchmark ============ + results = [] + from sgl_kernel import fp8_scaled_mm + + for M in M_list: + print(f"\n{'='*80}") + print(f"Benchmarking M={M}") + print(f"{'='*80}") + + A = test_data[M]["A"] + Ascale = test_data[M]["Ascale"] + out = test_data[M]["out"] + + # 验证正确性 + print(f"[M={M}] Verifying correctness...") + d_A = A.to(output_dtype) * Ascale.to(output_dtype) + d_B = B.to(output_dtype) * Bscale.to(output_dtype) + gt_C = d_A.mm(d_B) + + # 运行一次确保结果正确 + fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + sgl_res = fp8_scaled_mm(A, B, Ascale, Bscale, output_dtype) + + cosine_sim = F.cosine_similarity(out.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) + sgl_cosine_sim = F.cosine_similarity(sgl_res.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) + print(f"[M={M}] Cosine Similarity - Our: {cosine_sim.item():.6f}, SGL: {sgl_cosine_sim.item():.6f}") + + # Benchmark 性能 + print(f"[M={M}] Benchmarking performance...") + + # BF16 baseline + fn_bf16 = lambda: torch.mm(d_A, d_B) + ms_bf16 = triton.testing.do_bench(fn_bf16, warmup=25, rep=100) + + # SGL kernel + fn_sgl = lambda: fp8_scaled_mm(A, B, Ascale, Bscale, output_dtype) + ms_sgl = triton.testing.do_bench(fn_sgl, warmup=25, rep=100) + + # Our kernel + fn_ours = lambda: fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + ms_ours = triton.testing.do_bench_cudagraph(fn_ours, rep=100) + + print(f"[M={M}] BF16: {ms_bf16:.3f} ms") + print(f"[M={M}] SGL FP8: {ms_sgl:.3f} ms ({ms_bf16/ms_sgl:.2f}x)") + print(f"[M={M}] Our FP8: {ms_ours:.3f} ms ({ms_bf16/ms_ours:.2f}x)") + + results.append( + { + "M": M, + "bf16_ms": ms_bf16, + "sgl_ms": ms_sgl, + "ours_ms": ms_ours, + "cosine_sim": cosine_sim.item(), + } + ) + + # 打印汇总结果 + print(f"\n{'='*80}") + print("SUMMARY - Performance Comparison") + print(f"{'='*80}") + print(f"{'M':<8} {'BF16(ms)':<12} {'SGL(ms)':<12} {'Our(ms)':<12} {'vs BF16':<10} {'vs SGL':<10}") + print(f"{'-'*80}") + for r in results: + vs_bf16 = f"{r['bf16_ms']/r['ours_ms']:.2f}x" + vs_sgl = 
f"{r['sgl_ms']/r['ours_ms']:.2f}x" + emoji = "🎉" if r["ours_ms"] < r["sgl_ms"] else "" + print( + f"{r['M']:<8} {r['bf16_ms']:<12.3f} {r['sgl_ms']:<12.3f}" + f"{r['ours_ms']:<12.3f} {vs_bf16:<10} {vs_sgl:<10} {emoji}" + ) + print(f"{'='*80}") diff --git a/lightllm/common/quantization/triton_quant/triton_quant.py b/lightllm/common/quantization/triton_quant/triton_quant.py index 410f925a5..9f6a7bee2 100644 --- a/lightllm/common/quantization/triton_quant/triton_quant.py +++ b/lightllm/common/quantization/triton_quant/triton_quant.py @@ -7,8 +7,7 @@ from .fp8.fp8act_quant_kernel import per_token_group_quant_fp8 from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from lightllm.common.quantization.quantize_method import WeightPack class TritonBaseQuantizationMethod(QuantizationMethod): @@ -18,16 +17,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): - pass + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> WeightPack: + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -44,17 +44,18 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: # TODO block-wise quant kernel - pass + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale @@ -83,3 +84,29 @@ def apply( dtype=input_tensor.dtype, ) return out + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[ + start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size + ].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_zero_point[ + start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size + ].copy_(weight_zero_point) + return diff --git a/lightllm/common/quantization/w8a8_quant.py 
b/lightllm/common/quantization/w8a8_quant.py index cec6d1778..8c5d1cc1e 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -9,8 +9,8 @@ from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + +from .quantize_method import WeightPack if HAS_LIGHTLLM_KERNEL: @@ -30,16 +30,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): - pass + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -47,6 +48,11 @@ def apply( def method_name(self): return "w8a8-base" + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + raise NotImplementedError("Not implemented") + @QUANTMETHODS.register(["vllm-w8a8", "w8a8"]) class w8a8QuantizationMethod(BaseQuantizationMethod): @@ -55,27 +61,27 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): - if isinstance(weight, tuple): - return (weight[0].transpose(0, 1).cuda(self.device_id_),) + weight[1:] - weight = weight.float() + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + weight = weight.float().cuda(self.device_id_) scale = weight.abs().max(dim=-1)[0] / 127 - weight = weight.transpose(0, 1) / scale.reshape(1, -1) + weight = weight / scale.reshape(-1, 1) weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) - return weight.cuda(self.device_id_), scale.cuda(self.device_id_), None + output.weight[offset : offset + weight.shape[0]].copy_(weight) + output.weight_scale[offset : offset + weight.shape[0]].copy_(scale) + return def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_scale = None - qweight = weight_pack.weight + qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale - bias = weight_pack.bias input_scale = None # dynamic quantization for input tensor x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True) m = input_tensor.shape[0] @@ -94,6 +100,14 @@ def apply( def method_name(self): return "vllm-w8a8" + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + @QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"]) class 
FP8w8a8QuantizationMethod(BaseQuantizationMethod): @@ -103,19 +117,20 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: if self.is_moe: - return self.quantize_moe(weight) + return self.quantize_moe(weight, output, offset) qweight, weight_scale = scaled_fp8_quant( - weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True + weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) - return qweight.transpose(0, 1), weight_scale, None + output.weight[offset : offset + qweight.shape[0], :].copy_(qweight) + output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) + return - def quantize_moe(self, weight: torch.Tensor): + def quantize_moe(self, weight: torch.Tensor) -> WeightPack: num_experts = weight.shape[0] - qweights = [] - weight_scales = [] qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) + weight_scales = [] for i in range(num_experts): qweight, weight_scale = scaled_fp8_quant( weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True @@ -123,19 +138,19 @@ def quantize_moe(self, weight: torch.Tensor): qweights[i] = qweight weight_scales.append(weight_scale) weight_scale = torch.stack(weight_scales, dim=0).contiguous() - return qweights, weight_scale + return WeightPack(weight=qweights, weight_scale=weight_scale) def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weight_pack.weight + qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale - bias = weight_pack.bias x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) m = input_tensor.shape[0] n = qweight.shape[1] @@ -153,6 +168,14 @@ def apply( def method_name(self): return "vllm-fp8w8a8" + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + @QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"]) class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod): @@ -163,21 +186,26 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant - raise Exception("Not implemented") + device = output.weight.device + weight, scale = weight_quant(weight.cuda(device), self.block_size) + output.weight[offset : offset + weight.shape[0], :].copy_(weight) + output.weight_scale[offset // self.block_size : offset + weight.shape[0] // self.block_size].copy_(scale) + return def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = 
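
The vllm-w8a8 method above now quantizes directly into a caller-provided WeightPack slice (output, offset) instead of returning fresh tensors. The sketch below shows the same per-output-channel symmetric INT8 scheme in isolation; WeightPack here is a stand-in dataclass rather than the real class from lightllm.common.quantization.quantize_method, and the small epsilon guard on the scale is an addition the diff itself does not have.

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class WeightPack:
    # Stand-in for LightLLM's WeightPack (weight, weight_scale, weight_zero_point).
    weight: torch.Tensor
    weight_scale: Optional[torch.Tensor] = None
    weight_zero_point: Optional[torch.Tensor] = None

def quantize_int8_per_channel(weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None:
    # Symmetric per-output-channel INT8 quantization, written in place at `offset`,
    # in the spirit of w8a8QuantizationMethod.quantize above.
    w = weight.float()
    scale = (w.abs().amax(dim=-1) / 127.0).clamp(min=1e-8)  # epsilon guard added here
    q = torch.round(w / scale.unsqueeze(-1)).clamp(-128, 127).to(torch.int8)
    output.weight[offset : offset + q.shape[0]].copy_(q)
    output.weight_scale[offset : offset + q.shape[0]].copy_(scale)

if __name__ == "__main__":
    out_dim, in_dim = 8, 16
    pack = WeightPack(
        weight=torch.empty(out_dim, in_dim, dtype=torch.int8),
        weight_scale=torch.empty(out_dim, dtype=torch.float32),
    )
    w0 = torch.randn(out_dim, in_dim)
    quantize_int8_per_channel(w0, pack)
    dequant = pack.weight.float() * pack.weight_scale.unsqueeze(-1)
    print("max abs error vs original:", (dequant - w0).abs().max().item())
```
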
None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - bias = weight_pack.bias + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale.t() input_scale = None # dynamic quantization for input tensor m, k = input_tensor.shape n = qweight.shape[1] @@ -206,3 +234,13 @@ def apply( @property def method_name(self): return "vllm-fp8w8a8-b128" + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index b740bb62f..b03ed061d 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -6,21 +6,36 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_vob_size = vob_size // self.tp_world_size_ + + # Pre-allocate memory for weights + self.pre_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.pre_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + return def load_hf_weights(self, weights): if "word_embeddings_layernorm.weight" in weights: - self.pre_norm_weight_ = self._cuda(weights["word_embeddings_layernorm.weight"]) + self.pre_norm_weight_.copy_(weights["word_embeddings_layernorm.weight"]) if "word_embeddings_layernorm.bias" in weights: - self.pre_norm_bias_ = self._cuda(weights["word_embeddings_layernorm.bias"]) + self.pre_norm_bias_.copy_(weights["word_embeddings_layernorm.bias"]) if "ln_f.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["ln_f.weight"]) + self.final_norm_weight_.copy_(weights["ln_f.weight"]) if "ln_f.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["ln_f.bias"]) + self.final_norm_bias_.copy_(weights["ln_f.bias"]) if "word_embeddings.weight" in weights: vob_size = self.network_config_["vocab_size"] split_vob_size = vob_size // self.tp_world_size_ - self.wte_weight_ = self._cuda( + self.wte_weight_.copy_( weights["word_embeddings.weight"][ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : ] diff --git a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py index 0b125bea3..a93b30f94 100644 --- 
a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py @@ -13,19 +13,16 @@ def load_hf_weights(self, weights): vob_size = self.network_config_["padded_vocab_size"] split_vob_size = vob_size // self.tp_world_size_ if "transformer.embedding.word_embeddings.weight" in weights: - self.wte_weight_ = weights["transformer.embedding.word_embeddings.weight"] - self.wte_weight_ = self.wte_weight_[ + wte_weight = weights["transformer.embedding.word_embeddings.weight"] + self.wte_weight_.copy_(wte_weight[ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - self.wte_weight_ = self._cuda(self.wte_weight_) + ]) if "transformer.output_layer.weight" in weights: - self.lm_head_weight_ = weights["transformer.output_layer.weight"] - self.lm_head_weight_ = self.lm_head_weight_[ + lm_head_weight = weights["transformer.output_layer.weight"] + self.lm_head_weight_.copy_(lm_head_weight[ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : - ] - self.lm_head_weight_ = self._cuda(self.lm_head_weight_) + ]) if "transformer.encoder.final_layernorm.weight" in weights: - self.final_norm_weight_ = weights["transformer.encoder.final_layernorm.weight"] - self.final_norm_weight_ = self._cuda(self.final_norm_weight_) + self.final_norm_weight_.copy_(weights["transformer.encoder.final_layernorm.weight"]) return diff --git a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py index 993acd64d..ed550fecf 100644 --- a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py @@ -13,13 +13,13 @@ def load_hf_weights(self, weights): split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) if tie_weight: self.lm_head_weight_ = self.wte_weight_ if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) if "model.lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["model.lm_head.weight"]) + self.lm_head_weight_.copy_(weights["model.lm_head.weight"]) return def verify_load(self): diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py index fff92abf5..1dce8b51f 100644 --- a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py @@ -18,14 +18,14 @@ def _init_norm(self, weights): q_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ k_split_head = self.network_config_["num_key_value_heads"] // self.tp_world_size_ - self.att_norm_weight_ = NormWeight(self._att_norm_weight_name, self.data_type_) + self.att_norm_weight_ = NormWeight(self.n_embed, self._att_norm_weight_name, self.data_type_) if self.use_qk_norm: self.q_norm_weight_ = TpNormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_, q_split_head + q_split_head, f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_ ) self.k_norm_weight_ = TpNormWeight( - 
f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_, k_split_head + k_split_head, f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_ ) return diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index c899751eb..390a26aa8 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -299,14 +299,14 @@ def _init_ffn(self): self._load_mlp(f"model.layers.{self.layer_num_}.mlp") def _init_norm(self): - self.att_norm_weight_ = NormWeight(f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_) + self.att_norm_weight_ = NormWeight(self.network_config_["n_embed"], f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_) self.ffn_norm_weight_ = NormWeight( - f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", self.data_type_ + self.network_config_["n_embed"], f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", self.data_type_ ) self.kv_a_layernorm_ = NormWeight( - f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", self.data_type_ + self.network_config_["n_embed"], f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", self.data_type_ ) if self.q_lora_rank is not None: self.q_a_layernorm_ = NormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", self.data_type_ + self.network_config_["n_embed"], f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", self.data_type_ ) diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index f5b805647..66131a858 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -1,3 +1,4 @@ +import torch import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight @@ -10,15 +11,26 @@ def __init__(self, data_type, network_config, mode): self.lm_head_weight_ = None return + def _create_weight(self): + hidden_size = self.network_config_["hidden_size"] + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + + # Pre-allocate memory for weights + self.eh_proj_weight_ = torch.empty((moe_intermediate_size, hidden_size), dtype=self.data_type_).cuda() + self.enorm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.hnorm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + return + def load_hf_weights(self, weights): if "model.layers.0.eh_proj.weight" in weights: - self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t() + self.eh_proj_weight_.copy_(weights["model.layers.0.eh_proj.weight"].t()) if "model.layers.0.enorm.weight" in weights: - self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"]) + self.enorm_weight_.copy_(weights["model.layers.0.enorm.weight"]) if "model.layers.0.hnorm.weight" in weights: - self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"]) + self.hnorm_weight_.copy_(weights["model.layers.0.hnorm.weight"]) if "model.layers.0.shared_head.norm.weight" in weights: - self.final_norm_weight_ = 
self._cuda(weights["model.layers.0.shared_head.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.layers.0.shared_head.norm.weight"]) return def verify_load(self): diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 6f5530461..ec05a98ed 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -66,13 +66,13 @@ def _init_qkv(self): def _init_norm(self): super()._init_norm() - self.k_norm_weight_ = NormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) - self.q_norm_weight_ = NormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) + self.k_norm_weight_ = NormWeight(self.head_dim, self._k_norm_weight_name, self.data_type_, bias_name=None) + self.q_norm_weight_ = NormWeight(self.head_dim, self._q_norm_weight_name, self.data_type_, bias_name=None) self.pre_feedforward_layernorm_weight_ = NormWeight( - self._pre_feedforward_layernorm_name, self.data_type_, bias_name=None + self.n_embed, self._pre_feedforward_layernorm_name, self.data_type_, bias_name=None ) self.post_feedforward_layernorm_weight_ = NormWeight( - self._post_feedforward_layernorm_name, self.data_type_, bias_name=None + self.n_embed, self._post_feedforward_layernorm_name, self.data_type_, bias_name=None ) def load_hf_weights(self, weights): diff --git a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index c119960c5..fe388d532 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -15,11 +15,11 @@ def load_hf_weights(self, weights): split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) self.lm_head_weight_ = self.wte_weight_ if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) self.final_norm_weight_ = self.final_norm_weight_ + 1 return diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 32248e6dd..49bc6150c 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -29,5 +29,5 @@ def _init_qkv(self): ) def _init_norm(self): - self.att_norm_weight_ = GEMMANormWeight(self._att_norm_weight_name, self.data_type_) - self.ffn_norm_weight_ = GEMMANormWeight(self._ffn_norm_weight_name, self.data_type_) + self.att_norm_weight_ = GEMMANormWeight(self.n_embed, self._att_norm_weight_name, self.data_type_) + self.ffn_norm_weight_ = GEMMANormWeight(self.n_embed, self._ffn_norm_weight_name, self.data_type_) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 7e6035dc5..55fcf4f33 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -2,7 +2,9 @@ import torch 
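
The pre/post-layer weight changes above (Bloom, ChatGLM2, Cohere, DeepSeek-MTP) share one pattern: _create_weight allocates fixed-shape device buffers once, and load_hf_weights fills them with in-place copy_ calls instead of rebinding attributes through self._cuda(...). Below is a minimal, framework-free sketch of that pattern, with an illustrative class name, config keys, and a CPU device so it runs anywhere.

```python
import torch

class TinyPreAndPostWeight:
    # Buffers get their final shape/dtype/device once; checkpoint loading only ever
    # writes into that same storage.
    def __init__(self, hidden_size, vocab_size, tp_rank, tp_world_size,
                 dtype=torch.float16, device="cuda"):
        self.tp_rank = tp_rank
        split_vocab = vocab_size // tp_world_size
        self.wte_weight_ = torch.empty(split_vocab, hidden_size, dtype=dtype, device=device)
        self.final_norm_weight_ = torch.empty(hidden_size, dtype=dtype, device=device)

    def load_hf_weights(self, weights: dict) -> None:
        split_vocab = self.wte_weight_.shape[0]
        start, end = split_vocab * self.tp_rank, split_vocab * (self.tp_rank + 1)
        if "model.embed_tokens.weight" in weights:
            # copy_() casts the dtype and moves host -> device in place.
            self.wte_weight_.copy_(weights["model.embed_tokens.weight"][start:end, :])
        if "model.norm.weight" in weights:
            self.final_norm_weight_.copy_(weights["model.norm.weight"])

if __name__ == "__main__":
    w = TinyPreAndPostWeight(hidden_size=8, vocab_size=32, tp_rank=0, tp_world_size=2,
                             dtype=torch.float32, device="cpu")
    w.load_hf_weights({"model.embed_tokens.weight": torch.randn(32, 8),
                       "model.norm.weight": torch.ones(8)})
    print(w.wte_weight_.shape)  # torch.Size([16, 8])
```

One likely motivation, hinted at by the enable_weight_cpu_backup and enable_torch_memory_saver flags added later in this series, is that stable tensor objects let weights be restored or reloaded without invalidating references held elsewhere.
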
import numpy as np -from lightllm.common.basemodel.layer_weights.meta_weights.gpt_oss_fused_moe_weight_tp import GPTOSSFusedMoeWeightTP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.gpt_oss_fused_moe_weight_tp import ( + GPTOSSFusedMoeWeightTP, +) from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight, TpNormWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight @@ -71,7 +73,7 @@ def _init_weight(self): super()._init_weight() n_split_head = self.network_config_["num_attention_heads"] // self.tp_world_size_ - self.attn_sinks = TpNormWeight(self._attn_sink_name, torch.bfloat16, n_split_head) + self.attn_sinks = TpNormWeight(n_split_head, self._attn_sink_name, torch.bfloat16) def _init_ffn(self): self._init_moe() diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index dd8c64915..c77269db8 100644 --- a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -8,16 +8,30 @@ def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + return + def load_hf_weights(self, weights): vob_size = self.network_config_["vocab_size"] split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.tok_embeddings.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.tok_embeddings.weight"][split_start:split_end, :]) if "output.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["output.weight"][split_start:split_end, :]) + self.lm_head_weight_.copy_(weights["output.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index 78fb0c5d7..7735a3f30 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -14,10 +14,10 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.tok_embeddings.weight" in weights: - self.wte_weight_ = 
self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.tok_embeddings.weight"][split_start:split_end, :]) if "v_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["v_head.weight"]).transpose(0, 1) + self.lm_head_weight_.copy_(weights["v_head.weight"].transpose(0, 1)) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 711406e3f..98cf0d51c 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -7,6 +7,18 @@ class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + return + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_vob_size = vob_size // self.tp_world_size_ + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() return def load_hf_weights(self, weights): @@ -15,14 +27,14 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) if tie_word_embeddings: self.lm_head_weight_ = self.wte_weight_ if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 624717007..8ca0fe15d 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -24,11 +24,16 @@ def _init_weight(self): self._init_norm() def _parse_config(self): + self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1) + self.tp_v_head_num_ = self.tp_k_head_num_ + self.tp_o_head_num_ = self.tp_q_head_num_ + head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] + self.head_dim = self.network_config_.get("head_dim", head_dim) + assert (self.tp_k_head_num_ * self.tp_world_size_) % self.network_config_["num_key_value_heads"] == 0 self.n_embed = self.network_config_["hidden_size"] - self.n_head = self.network_config_["num_attention_heads"] self.n_inter = 
self.network_config_["intermediate_size"] - self.n_kv_head = self.network_config_["num_key_value_heads"] - self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head) + self.n_head = self.network_config_["num_attention_heads"] def _init_weight_names(self): self._q_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" @@ -57,55 +62,63 @@ def _init_weight_names(self): self._ffn_norm_bias_name = None def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.tp_q_head_num_ * self.head_dim + k_out_dim = self.tp_k_head_num_ * self.head_dim + v_out_dim = self.tp_v_head_num_ * self.head_dim self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], weight_names=self._q_weight_name, data_type=self.data_type_, bias_names=self._q_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_proj", + quant_method=self.get_quant_method("q_proj"), ) self.kv_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[k_out_dim, v_out_dim], weight_names=[self._k_weight_name, self._v_weight_name], data_type=self.data_type_, bias_names=[self._k_bias_name, self._v_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="kv_proj", + quant_method=self.get_quant_method("kv_proj"), ) def _init_o(self): + in_dim = self.tp_o_head_num_ * self.head_dim + out_dim = self.n_embed self.o_proj = COLMMWeight( + in_dim=in_dim, + out_dims=[out_dim], weight_names=self._o_weight_name, data_type=self.data_type_, bias_names=self._o_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="o_proj", + quant_method=self.get_quant_method("o_proj"), ) def _init_ffn(self): + in_dim = self.n_embed + out_dim = self.n_inter // self.tp_world_size_ self.gate_up_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[out_dim, out_dim], weight_names=[self._gate_weight_name, self._up_weight_name], data_type=self.data_type_, bias_names=[self._gate_bias_name, self._up_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=out_dim, + out_dims=[in_dim], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) def _init_norm(self): self.att_norm_weight_ = NormWeight( - self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + self.n_embed, self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name ) self.ffn_norm_weight_ = NormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + self.n_embed, self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) diff --git a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py index 94e1a27e0..0a9230b5b 100644 --- a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py @@ -18,11 +18,11 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) if 
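
LlamaTransformerLayerWeight._parse_config above now derives the per-rank head counts itself, and ROWMMWeight / COLMMWeight receive explicit in_dim / out_dims. The arithmetic is pulled out into a standalone helper here for clarity; the names mirror the diff and the example numbers are made up.

```python
def qkv_o_dims(hidden_size, num_attention_heads, num_key_value_heads, tp_world_size, head_dim=None):
    # Per-rank head counts, as computed in _parse_config above.
    tp_q_heads = num_attention_heads // tp_world_size
    tp_kv_heads = max(num_key_value_heads // tp_world_size, 1)
    head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads

    # ROWMMWeight above takes the full hidden size as in_dim and the per-rank
    # projection widths as out_dims; COLMMWeight (o_proj) is the reverse.
    return {
        "q_proj":  (hidden_size, [tp_q_heads * head_dim]),
        "kv_proj": (hidden_size, [tp_kv_heads * head_dim, tp_kv_heads * head_dim]),
        "o_proj":  (tp_q_heads * head_dim, [hidden_size]),
    }

# Example with a GQA config split over 2 ranks (numbers are illustrative):
print(qkv_o_dims(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8, tp_world_size=2))
# {'q_proj': (4096, [2048]), 'kv_proj': (4096, [512, 512]), 'o_proj': (2048, [4096])}
```
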
"lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) / self.lm_head_scale + self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :] / self.lm_head_scale) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index f425ad08b..496768710 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -62,6 +62,7 @@ def _init_moe(self): layer_num=self.layer_num_, quant_cfg=self.quant_cfg, num_fused_shared_experts=0, + hidden_size=self.network_config_.get("hidden_size"), ) else: raise ValueError(f"Unsupported moe mode: {moe_mode}") diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index 95af6ecd3..a07a55e8c 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -6,6 +6,18 @@ class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + return + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_vob_size = vob_size // self.tp_world_size_ + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() return def load_hf_weights(self, weights): @@ -14,17 +26,17 @@ def load_hf_weights(self, weights): split_vob_size = vob_size // self.tp_world_size_ if "transformer.wte.weight" in weights: - self.wte_weight_ = self._cuda( + self.wte_weight_.copy_( weights["transformer.wte.weight"][ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : ] ) if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda( + self.lm_head_weight_.copy_( weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :] ) if "transformer.ln_f.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["transformer.ln_f.weight"]) + self.final_norm_weight_.copy_(weights["transformer.ln_f.weight"]) return diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index 5735b0339..772400a1e 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -6,6 +6,21 @@ class Qwen2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + return + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = 
split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() return def load_hf_weights(self, weights): @@ -14,14 +29,14 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) if tie_word_embeddings: self.lm_head_weight_ = self.wte_weight_ if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) return diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 2e2c0d3bb..d65353108 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -14,15 +14,6 @@ def _init_weight_names(self): self._k_bias_name = f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" self._v_bias_name = f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" - def _parse_config(self): - self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_ - self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1) - self.tp_v_head_num_ = self.tp_k_head_num_ - self.tp_o_head_num_ = self.tp_q_head_num_ - head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] - self.head_dim = self.network_config_.get("head_dim", head_dim) - assert (self.tp_k_head_num_ * self.tp_world_size_) % self.network_config_["num_key_value_heads"] == 0 - def _repeat_weight(self, name, weights): # for tp_world_size_ > num_key_value_heads if name not in weights: diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py index a56c5d6cb..9babad35a 100644 --- a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py @@ -9,29 +9,49 @@ def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), 
dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + + # Reward model specific weights + self.score_up_weight = torch.empty((hidden_size, 1), dtype=self.data_type_).cuda() + self.score_up_bias = torch.empty(1, dtype=self.data_type_).cuda() + self.score_down_weight = torch.empty((hidden_size, 1), dtype=self.data_type_).cuda() + self.score_down_bias = torch.empty(1, dtype=self.data_type_).cuda() + return + def load_hf_weights(self, weights): vob_size = self.network_config_["vocab_size"] split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) if tie_word_embeddings: self.lm_head_weight_ = self.wte_weight_ if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) if "score.0.weight" in weights: - self.score_up_weight = self._cuda(weights["score.0.weight"]).transpose(0, 1) + self.score_up_weight.copy_(weights["score.0.weight"].transpose(0, 1)) if "score.0.bias" in weights: - self.score_up_bias = self._cuda(weights["score.0.bias"]) + self.score_up_bias.copy_(weights["score.0.bias"]) if "score.2.weight" in weights: - self.score_down_weight = self._cuda(weights["score.2.weight"]).transpose(0, 1) + self.score_down_weight.copy_(weights["score.2.weight"].transpose(0, 1)) if "score.2.bias" in weights: - self.score_down_bias = self._cuda(weights["score.2.bias"]) + self.score_down_bias.copy_(weights["score.2.bias"]) return diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 4c0ef586f..dcee72a1c 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -20,5 +20,5 @@ def _init_weight_names(self): def _init_norm(self): super()._init_norm() - self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) - self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) + self.q_norm_weight_ = NormWeight(self.head_dim, weight_name=self._q_norm_name, data_type=self.data_type_) + self.k_norm_weight_ = NormWeight(self.head_dim, weight_name=self._k_norm_name, data_type=self.data_type_) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 72721f9d6..bc4b54819 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -53,10 +53,11 @@ def _init_weight(self): def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", data_type=self.data_type_, - layer_num=self.layer_num_, - name="moe_gate", + quant_method=None, tp_rank=0, tp_world_size=1, ) diff --git 
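
Several of the _create_weight / load_hf_weights implementations above slice the vocabulary across tensor-parallel ranks with np.linspace rather than plain integer division. A small sketch of why that works even when the vocabulary size is not divisible by the world size; the sizes are invented for the example.

```python
import numpy as np

def vocab_split(vocab_size: int, tp_world_size: int, tp_rank: int):
    # linspace over tp_world_size + 1 points gives monotonic integer boundaries whose
    # consecutive differences always sum to exactly vocab_size.
    split_indexes = np.linspace(0, vocab_size, tp_world_size + 1, dtype=np.int64)
    return int(split_indexes[tp_rank]), int(split_indexes[tp_rank + 1])

if __name__ == "__main__":
    vocab, world = 32003, 4   # deliberately not divisible
    spans = [vocab_split(vocab, world, r) for r in range(world)]
    assert sum(end - start for start, end in spans) == vocab
    print(spans)  # [(0, 8000), (8000, 16001), (16001, 24002), (24002, 32003)]
```
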
a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py index 80966c7b4..9002c463d 100755 --- a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py @@ -8,6 +8,21 @@ def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + return + def load_hf_weights(self, weights): vob_size = self.network_config_["vocab_size"] split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) @@ -15,14 +30,14 @@ def load_hf_weights(self, weights): split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: # print(weights['model.embed_tokens.weight'].shape) - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) if "lm_head.weight" in weights: # print(weights['lm_head.weight'].shape) - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) if "model.norm.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["model.norm.bias"]) + self.final_norm_bias_.copy_(weights["model.norm.bias"]) return diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index 8d87c1163..b54fc068b 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -6,6 +6,21 @@ class StarcoderPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_vob_size = vob_size // self.tp_world_size_ + max_position_embeddings = self.network_config_["max_position_embeddings"] + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.wpe_weight_ = torch.empty((max_position_embeddings, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + 
self.final_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + return def load_hf_weights(self, weights): @@ -13,28 +28,22 @@ def load_hf_weights(self, weights): split_vob_size = vob_size // self.tp_world_size_ if "transformer.wte.weight" in weights: # print(weights['transformer.wte.weight'].shape) - self.wte_weight_ = ( + self.wte_weight_.copy_( weights["transformer.wte.weight"][ split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), : ] - .contiguous() - .to(self.data_type_) - .cuda() ) if "transformer.wpe.weight" in weights: # print(weights['transformer.wpe.weight'].shape) - self.wpe_weight_ = weights["transformer.wpe.weight"].to(self.data_type_).cuda() + self.wpe_weight_.copy_(weights["transformer.wpe.weight"]) if "lm_head.weight" in weights: - self.lm_head_weight_ = ( + self.lm_head_weight_.copy_( weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :] - .contiguous() - .to(self.data_type_) - .cuda() ) if "transformer.ln_f.weight" in weights: - self.final_norm_weight_ = weights["transformer.ln_f.weight"].contiguous().to(self.data_type_).cuda() + self.final_norm_weight_.copy_(weights["transformer.ln_f.weight"]) if "transformer.ln_f.bias" in weights: - self.final_norm_bias_ = weights["transformer.ln_f.bias"].contiguous().to(self.data_type_).cuda() + self.final_norm_bias_.copy_(weights["transformer.ln_f.bias"]) return def verify_load(self): diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index fd2d47575..cfe1969c0 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -1,3 +1,4 @@ +import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight @@ -5,6 +6,22 @@ class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + return + + def _create_weight(self): + vob_size = self.network_config_["vocab_size"] + hidden_size = self.network_config_["hidden_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_vob_size = split_end - split_start + + # Pre-allocate memory for weights + self.wte_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.lm_head_weight_ = torch.empty((split_vob_size, hidden_size), dtype=self.data_type_).cuda() + self.final_norm_weight_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() + self.final_norm_bias_ = torch.empty(hidden_size, dtype=self.data_type_).cuda() return def load_hf_weights(self, weights): @@ -13,19 +30,19 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "model.embed_tokens.weight" in weights: - self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + self.wte_weight_.copy_(weights["model.embed_tokens.weight"][split_start:split_end, :]) # for starcoder2-3b and 7b which didn't use lm_head.weight (tie_word_embeddings) self.lm_head_weight_ = self.wte_weight_ if "lm_head.weight" in weights: - self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + 
self.lm_head_weight_.copy_(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: - self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + self.final_norm_weight_.copy_(weights["model.norm.weight"]) if "model.norm.bias" in weights: - self.final_norm_bias_ = self._cuda(weights["model.norm.bias"]) + self.final_norm_bias_.copy_(weights["model.norm.bias"]) return diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 276d4e5d0..1a6f76fde 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -13,6 +13,34 @@ def __init__(self, data_type, network_config, mode): self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] self.llm_hidden_size = self.network_config_["llm_hidden_size"] + self._create_weight() + return + + def _create_weight(self): + split_indexes = np.linspace(0, self.embed_dim, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_embed_dim = split_end - split_start + + # Pre-allocate memory for vision model weights + self.class_embedding = torch.empty((1, 1, split_embed_dim), dtype=self.data_type_).cuda() + self.position_embedding = torch.empty((1, 197, split_embed_dim), dtype=self.data_type_).cuda() # 197 = (224//16)^2 + 1 + self.patch_embedding_weight_ = torch.empty((split_embed_dim, 3, self.patch_size, self.patch_size), dtype=self.data_type_).cuda() + self.patch_embedding_bias_ = torch.empty(split_embed_dim, dtype=self.data_type_).cuda() + + # Pre-allocate memory for adapter weights + self.layernorm_weight_ = torch.empty(self.embed_dim, dtype=self.data_type_).cuda() + self.layernorm_bias_ = torch.empty(self.embed_dim, dtype=self.data_type_).cuda() + + split_indexes_llm = np.linspace(0, self.llm_hidden_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start_llm = split_indexes_llm[self.tp_rank_] + split_end_llm = split_indexes_llm[self.tp_rank_ + 1] + split_llm_hidden_size = split_end_llm - split_start_llm + + self.mlp1_1_weight_ = torch.empty((self.llm_hidden_size, split_llm_hidden_size), dtype=self.data_type_).cuda() + self.mlp1_1_bias_ = torch.empty(split_llm_hidden_size, dtype=self.data_type_).cuda() + self.mlp1_3_weight_ = torch.empty((split_llm_hidden_size, self.llm_hidden_size), dtype=self.data_type_).cuda() + self.mlp1_3_bias_ = torch.empty(self.llm_hidden_size, dtype=self.data_type_).cuda() return def _cuda(self, cpu_tensor): @@ -40,40 +68,40 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "vision_model.embeddings.class_embedding" in weights: - self.class_embedding = self._cuda( + self.class_embedding.copy_( weights["vision_model.embeddings.class_embedding"][:, :, split_start:split_end] ) if "vision_model.embeddings.position_embedding" in weights: - self.position_embedding = self._cuda( + self.position_embedding.copy_( weights["vision_model.embeddings.position_embedding"][:, :, split_start:split_end] ) if "vision_model.embeddings.patch_embedding.weight" in weights: - self.patch_embedding_weight_ = self._cuda( + self.patch_embedding_weight_.copy_( weights["vision_model.embeddings.patch_embedding.weight"][split_start:split_end, :, :, :] ) if "vision_model.embeddings.patch_embedding.bias" in weights: - self.patch_embedding_bias_ 
= self._cuda( + self.patch_embedding_bias_.copy_( weights["vision_model.embeddings.patch_embedding.bias"][split_start:split_end] ) if "mlp1.0.weight" in weights: - self.layernorm_weight_ = self._cuda(weights["mlp1.0.weight"]) + self.layernorm_weight_.copy_(weights["mlp1.0.weight"]) if "mlp1.0.bias" in weights: - self.layernorm_bias_ = self._cuda(weights["mlp1.0.bias"]) + self.layernorm_bias_.copy_(weights["mlp1.0.bias"]) split_indexes = np.linspace(0, self.llm_hidden_size, self.tp_world_size_ + 1, dtype=np.int64) split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "mlp1.1.weight" in weights: - self.mlp1_1_weight_ = self._cuda(weights["mlp1.1.weight"][split_start:split_end, :]).t() + self.mlp1_1_weight_.copy_(weights["mlp1.1.weight"][split_start:split_end, :].t()) if "mlp1.1.bias" in weights: - self.mlp1_1_bias_ = self._cuda(weights["mlp1.1.bias"][split_start:split_end]) + self.mlp1_1_bias_.copy_(weights["mlp1.1.bias"][split_start:split_end]) if "mlp1.3.weight" in weights: - self.mlp1_3_weight_ = self._cuda(weights["mlp1.3.weight"][:, split_start:split_end]).t() + self.mlp1_3_weight_.copy_(weights["mlp1.3.weight"][:, split_start:split_end].t()) if "mlp1.3.bias" in weights: - self.mlp1_3_bias_ = self._cuda(weights["mlp1.3.bias"]) + self.mlp1_3_bias_.copy_(weights["mlp1.3.bias"]) return diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index f1de0bdc1..05d8edbad 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -119,17 +119,18 @@ def _init_ffn(self): ) def _init_norm(self): + n_embed = self.network_config_["hidden_size"] self.att_norm_weight_ = NormWeight( - self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + n_embed, self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name ) self.ffn_norm_weight_ = NormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + n_embed, self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) if self.qk_norm: n_embed = self.network_config_["hidden_size"] split_n_embed = (n_embed + self.padding_hidden_size) // self.tp_world_size_ - self.q_norm_weight_ = TpNormWeight(self._q_norm_weight_name, self.data_type_, split_n_embed) - self.k_norm_weight_ = TpNormWeight(self._k_norm_weight_name, self.data_type_, split_n_embed) + self.q_norm_weight_ = TpNormWeight(split_n_embed, self._q_norm_weight_name, self.data_type_) + self.k_norm_weight_ = TpNormWeight(split_n_embed, self._k_norm_weight_name, self.data_type_) def load_hf_weights(self, weights): if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: diff --git a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py index 811d39a72..e3a71379d 100644 --- a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py @@ -8,10 +8,10 @@ import json from typing import List from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_ep_redundancy import ( +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_ep_redundancy import ( FusedMoeWeightEPAutoRedundancy, ) 
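
The ViT pre/post-layer diff above can copy weights[...].t() straight into its buffers only because _create_weight already allocates them in the transposed (in_features, out_features) layout. A standalone illustration of that shape bookkeeping; the tensor names and sizes are illustrative, not the actual ViT dimensions.

```python
import torch

def create_transposed_buffer(out_features: int, in_features: int,
                             dtype: torch.dtype = torch.float16, device: str = "cpu") -> torch.Tensor:
    # The matmul side wants (in_features, out_features), i.e. the checkpoint weight
    # transposed, so the buffer is allocated in that layout up front.
    return torch.empty(in_features, out_features, dtype=dtype, device=device)

def load_transposed(buffer: torch.Tensor, hf_weight: torch.Tensor) -> None:
    # HF checkpoints store nn.Linear weights as (out_features, in_features);
    # copy_() accepts the non-contiguous transpose view directly.
    assert hf_weight.shape == (buffer.shape[1], buffer.shape[0])
    buffer.copy_(hf_weight.t())

if __name__ == "__main__":
    out_f, in_f = 6, 4
    w_ckpt = torch.arange(out_f * in_f, dtype=torch.float32).reshape(out_f, in_f)
    buf = create_transposed_buffer(out_f, in_f, dtype=torch.float32)
    load_transposed(buf, w_ckpt)
    assert torch.equal(buf, w_ckpt.t())
    print(buf.shape)  # torch.Size([4, 6])
```
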
-from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_ep import FusedMoeWeightEP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_ep import FusedMoeWeightEP from lightllm.utils.envs_utils import get_env_start_args, get_redundancy_expert_update_interval from lightllm.utils.envs_utils import get_redundancy_expert_update_max_load_count from lightllm.utils.envs_utils import get_redundancy_expert_num From 3d225d70ca7d22846355a47dd6a30eee3e482f3a Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 25 Nov 2025 04:07:03 +0000 Subject: [PATCH 11/71] add_cli --- lightllm/server/api_cli.py | 3 +-- lightllm/server/api_server.py | 12 ++++++++---- lightllm/server/core/objs/start_args_type.py | 3 +++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index ee3f184e4..fa01dd068 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -1,8 +1,7 @@ import argparse -def make_argument_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() +def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument( "--run_mode", diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index dd531f58d..1eb5ff24c 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -1,16 +1,17 @@ import torch -from .api_cli import make_argument_parser +from .api_cli import add_cli_args from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) + def launch_server(args: StartArgs): from .api_start import pd_master_start, normal_or_p_d_start, config_server_start - + try: # this code will not be ok for settings to fork to subprocess - torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_start_method("spawn") except RuntimeError as e: logger.warning(f"Failed to set start method: {e}") except Exception as e: @@ -26,7 +27,10 @@ def launch_server(args: StartArgs): if __name__ == "__main__": - parser = make_argument_parser() + from argparse import ArgumentParser + + parser = ArgumentParser() + add_cli_args(parser) args = parser.parse_args() launch_server(StartArgs(**vars(args))) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index eff4dfab5..0af795a09 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -150,3 +150,6 @@ class StartArgs: enable_weight_cpu_backup: bool = field(default=False) weight_version: str = "default" + + enable_torch_memory_saver: bool = field(default=False) + enable_weight_cpu_backup: bool = field(default=False) From 499074a871727c2787ea90963112410d1bc8c1fc Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 8 Dec 2025 14:42:16 +0000 Subject: [PATCH 12/71] add 30b moe configs --- lightllm/common/basemodel/basemodel.py | 2 +- .../basemodel/layer_weights/hf_load_utils.py | 69 +---------- lightllm/common/quantization/w8a8_quant.py | 13 --- ...num=8,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++++++++++++++++++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++++++++++++++++++ ...orch.bfloat16,topk_num=8}_NVIDIA_H200.json | 74 ++++++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 74 ++++++++++++ lightllm/models/bloom/model.py | 9 -- lightllm/models/deepseek2/model.py | 9 -- lightllm/models/llama/model.py | 20 ---- lightllm/server/api_cli.py 
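
The add_cli patch above replaces make_argument_parser() with add_cli_args(parser), letting callers mount LightLLM's flags on a parser they already own (api_server.py now builds its own ArgumentParser and then calls add_cli_args). A minimal sketch of the pattern; --run_mode is taken from the real CLI, while the second flag and both defaults are invented for the example.

```python
import argparse

def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Mutates and returns the caller's parser instead of constructing a new one, so the
    # same flags can be mounted under sub-commands or embedded in another tool's CLI.
    parser.add_argument("--run_mode", type=str, default="normal")  # default is illustrative
    parser.add_argument("--port", type=int, default=8000)          # illustrative flag
    return parser

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="embedding lightllm-style args")
    add_cli_args(parser)
    args = parser.parse_args(["--port", "9000"])
    print(vars(args))  # {'run_mode': 'normal', 'port': 9000}
```
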
| 2 +- lightllm/server/api_start.py | 14 +-- 12 files changed, 378 insertions(+), 128 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 3eb5d7dbe..a0cafa25d 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -125,7 +125,7 @@ def __init__(self, kvargs): self._init_some_value() self._init_custom() self._init_inferstate_cls() - # self._autotune_warmup() + self._autotune_warmup() self._init_padded_req() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 2a9006efd..ec0e28284 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -30,7 +30,7 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay gc.collect() -def load_hf_weights_old(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): +def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): if isinstance(data_type, str): data_type = torch.float16 if data_type == "fp16" else torch.float32 if pre_post_layer is not None: @@ -67,74 +67,7 @@ def load_hf_weights_old(data_type, weight_dir, pre_post_layer=None, transformer_ iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" iterator = tqdm(iterator, total=len(candidate_files), desc=desc_str) - for _ in iterator: pass return - - -def _read_file(file_, use_safetensors, weight_dir): - if use_safetensors: - weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") - weights = {k: weights.get_tensor(k) for k in weights.keys()} - else: - weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") - - return weights - - -def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): - if isinstance(data_type, str): - data_type = torch.float16 if data_type == "fp16" else torch.float32 - if pre_post_layer is not None: - assert pre_post_layer.data_type_ == data_type, "type is not right" - if transformer_layer_list is not None: - assert transformer_layer_list[0].data_type_ == data_type, "type is not right" - if weight_dict: - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weight_dict) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weight_dict) - del weight_dict - return - use_safetensors 
= True - files = utils.PetrelHelper.list(weight_dir, extension="all") - candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files)) - if len(candidate_files) == 0: - use_safetensors = False - candidate_files = list(filter(lambda x: x.endswith(".bin"), files)) - assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." - - weight_queue = Queue(maxsize=5) # 控制内存使用 - - def producer(chunk): - for file_ in chunk: - weights = _read_file(file_, use_safetensors, weight_dir) - weight_queue.put(weights) - - LOADWORKER = int(os.environ.get("LOADWORKER", 1)) - - num_producers = min(LOADWORKER, len(candidate_files)) # 生产者数量 - chunk_size = (len(candidate_files) + num_producers - 1) // num_producers - file_chunks = [candidate_files[i : i + chunk_size] for i in range(0, len(candidate_files), chunk_size)] - - producer_threads = [] - for i, chunk in enumerate(file_chunks): - thread = Thread(target=producer, args=(chunk,), name=f"Producer-{i}") - thread.start() - producer_threads.append(thread) - - for _ in tqdm(range(len(candidate_files)), desc="Loading weights"): - weights = weight_queue.get() - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weights) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weights) - del weights - gc.collect() - - for thread in producer_threads: - thread.join() diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index 8c5d1cc1e..e4f7b552a 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -127,19 +127,6 @@ def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) return - def quantize_moe(self, weight: torch.Tensor) -> WeightPack: - num_experts = weight.shape[0] - qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) - weight_scales = [] - for i in range(num_experts): - qweight, weight_scale = scaled_fp8_quant( - weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True - ) - qweights[i] = qweight - weight_scales.append(weight_scale) - weight_scale = torch.stack(weight_scales, dim=0).contiguous() - return WeightPack(weight=qweights, weight_scale=weight_scale) - def apply( self, input_tensor: torch.Tensor, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..c75c871c7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 
8 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..14026090e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, 
+ "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json new file mode 100644 index 000000000..939c93952 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 2, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..13ba4ba8e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + 
"32768": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index 2c341a790..4a07a7ff5 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -56,13 +56,4 @@ def _init_weights(self): ) for i in range(self.config["n_layer"]) ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] return diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index a08147769..0ac24cf8b 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -129,15 +129,6 @@ def _init_weights(self): ) for i in range(self.config["n_layer"]) ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] return def _init_infer_layer(self): diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index abc258e8b..4ff802d81 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -146,26 +146,6 @@ def _init_weights(self): ) for i in range(self.config["n_layer"]) ] - if self.load_way == "HF": - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - else: - load_ds_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - prefix="model.layers.", - num_layer=self.config["n_layer"], - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] return def _init_to_get_rotary(self, default_base=10000): diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index fa01dd068..03f751d36 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -363,7 +363,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: "--visual_nccl_ports", nargs="+", type=int, - default=[29500], + default=[], help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 6a02dda17..ffd794b2d 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -181,13 +181,13 @@ def normal_or_p_d_start(args: StartArgs): args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] # 检查visual_nccl_port数量是否足够 - if 
len(args.visual_nccl_ports) < args.visual_dp: - raise ValueError( - f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " - f"but got ({len(args.visual_nccl_ports)})." - ) - else: - args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + # if len(args.visual_nccl_ports) < args.visual_dp: + # raise ValueError( + # f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " + # f"but got ({len(args.visual_nccl_ports)})." + # ) + # else: + # args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] if args.visual_dp <= 0: raise ValueError("visual_dp must be a positive integer.") From f73758582d30e67495c2eb718129b1cf7e6f79b5 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 9 Dec 2025 09:10:59 +0000 Subject: [PATCH 13/71] update requirement --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 20f27dc05..f062c9d2f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -89,3 +89,4 @@ orjson==3.11.2 setproctitle==1.3.6 xxhash==3.6.0 torch_memory_saver==0.0.9 +portpicker==1.6.0 From 8a67a4751063f0ec89566f18c1091440a7dab2aa Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 26 Dec 2025 10:21:19 +0000 Subject: [PATCH 14/71] add-neo-chat --- lightllm/models/__init__.py | 1 + lightllm/models/neo_chat/__init__.py | 0 .../models/neo_chat/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 7 + .../models/neo_chat/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 23 ++ .../layer_weights/transformer_layer_weight.py | 7 + lightllm/models/neo_chat/model.py | 138 +++++++++ lightllm/models/neo_chat/neo_visual.py | 273 ++++++++++++++++++ lightllm/models/neo_chat/vision_process.py | 141 +++++++++ lightllm/server/tokenizer.py | 3 + .../visualserver/model_infer/model_rpc.py | 3 + 12 files changed, 596 insertions(+) create mode 100644 lightllm/models/neo_chat/__init__.py create mode 100644 lightllm/models/neo_chat/layer_infer/__init__.py create mode 100644 lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/neo_chat/layer_weights/__init__.py create mode 100644 lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/neo_chat/model.py create mode 100644 lightllm/models/neo_chat/neo_visual.py create mode 100644 lightllm/models/neo_chat/vision_process.py diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 4ee02f003..5618dfd0c 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -38,4 +38,5 @@ Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel +from lightllm.models.neo_chat.model import NeoTpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/neo_chat/__init__.py b/lightllm/models/neo_chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat/layer_infer/__init__.py b/lightllm/models/neo_chat/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..c9297ee84 --- /dev/null +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -0,0 +1,7 @@ +from 
lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer + + +class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return diff --git a/lightllm/models/neo_chat/layer_weights/__init__.py b/lightllm/models/neo_chat/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..7766a5d29 --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." + keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..2dc87f3ca --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -0,0 +1,7 @@ +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight + + +class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py new file mode 100644 index 000000000..61fd98b98 --- /dev/null +++ b/lightllm/models/neo_chat/model.py @@ -0,0 +1,138 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer +from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.server.core.objs import SamplingParams +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from lightllm.models.neo_chat.vision_process import smart_resize +from lightllm.models.internvl.model import InternvlTokenizer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer +from 
lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer + +IMG_START_TOKEN = "" +IMG_END_TOKEN = "" +IMG_TOKEN = "" +AUDIO_START_TOKEN = "" + + +class NeoChatTokenizer(BaseMultiModalTokenizer): + def __init__(self, tokenizer, model_cfg, **kwargs): + super().__init__(tokenizer) + self.tokenizer = tokenizer + self.min_pixel = model_cfg.get("vision_config").get("min_pixels") + self.max_pixel = model_cfg.get("vision_config").get("max_pixels") + self.patch_size = model_cfg.get("vision_config").get("patch_size") + self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio") + + self.image_token_id = model_cfg.get("image_token_id") + self.image_start_tag = IMG_START_TOKEN + self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag) + self.image_end_tag = IMG_END_TOKEN + self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) + + def init_imageitem_extral_params( + self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + return + + def init_audioitem_extral_params( + self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + raise NotImplementedError + + def get_audio_token_length(self, audio: AudioItem): + raise NotImplementedError + + def get_image_token_length(self, img: ImageItem): + width, height = img.image_w, img.image_h + resized_height, resized_width = smart_resize( + height=height, + width=width, + factor=int(self.patch_size // self.downsample_ratio), + min_pixels=self.min_pixel, + max_pixels=self.max_pixel, + ) + grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size + token_num = (grid_h * grid_w) * (self.downsample_ratio ** 2) + # grid_thwd是为了mrope准备的,这里不需要 + img.grid_thwd = (1, grid_h, grid_w, 0) + return int(token_num) + + # only change the impl of the encode func: + def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): + # TEXTTEXTTEXT --> TEXTTEXTTEXT + image_tokens = IMG_START_TOKEN + IMG_END_TOKEN + if multimodal_params is None: + add_special_tokens = kwargs.get("add_special_tokens", True) + return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + image_count = len(multimodal_params.images) + prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) + + origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) + # --> id,id+1...id+num + input_ids = [] + image_id = 0 + start_idx = 0 + while True: + try: + start_idx = origin_ids.index(self.image_start_id, start_idx) + if start_idx + 1 >= len(origin_ids): + break + if origin_ids[start_idx + 1] == self.image_end_id: + input_ids.extend(origin_ids[: start_idx + 1]) + token_id = multimodal_params.images[image_id].token_id + token_num = multimodal_params.images[image_id].token_num + input_ids.extend(range(token_id, token_id + token_num)) + input_ids.append(self.image_end_id) + origin_ids = origin_ids[start_idx + 2 :] + start_idx = 0 + image_id += 1 + else: + raise ValueError("image token error") + except ValueError: + break + input_ids.extend(origin_ids[start_idx:]) + return input_ids + + +@ModelRegistry(["neo_chat"], is_multimodal=True) +class NeoTpPartModel(Qwen3MOEModel): + + 
pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = NeoChatMOETransformerLayerInfer + + pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight + transformer_weight_class = NeoChatMOETransformerLayerWeight + + infer_state_class = LlamaInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + pass + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["llm_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/neo_chat/neo_visual.py b/lightllm/models/neo_chat/neo_visual.py new file mode 100644 index 000000000..c9d4b8161 --- /dev/null +++ b/lightllm/models/neo_chat/neo_visual.py @@ -0,0 +1,273 @@ +import os +import torch +import torch.nn.functional as F +from PIL import Image +from typing import List +from io import BytesIO +import torch.nn as nn +from transformers.activations import ACT2FN +from safetensors import safe_open +from lightllm.server.multimodal_params import ImageItem +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from lightllm.models.neo_chat.vision_process import load_image_native +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data + + +def apply_rotary_emb_1d( + x: torch.Tensor, + cos_cached: torch.Tensor, + sin_cached: torch.Tensor, + positions: torch.Tensor, +): + """对输入张量的一部分应用1D RoPE。""" + # x: (..., seq_len, dim_part) + # positions: (..., seq_len) + # cos_cached: (max_pos, dim_part / 2) + cos_cached = cos_cached.to(device=positions.device) + sin_cached = sin_cached.to(device=positions.device) + + cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) + sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) + + x1 = x[..., 0::2] + x2 = x[..., 1::2] + + rotated_x1 = x1 * cos - x2 * sin + rotated_x2 = x1 * sin + x2 * cos + + x_rotated = torch.empty_like(x) + x_rotated[..., 0::2] = rotated_x1 + x_rotated[..., 1::2] = rotated_x2 + return x_rotated + + +def apply_2d_rotary_pos_emb( + x: torch.Tensor, + cos_cached_x: torch.Tensor, + sin_cached_x: torch.Tensor, + cos_cached_y: torch.Tensor, + sin_cached_y: torch.Tensor, + abs_positions_x: torch.Tensor, + abs_positions_y: torch.Tensor, +): + """应用2D RoPE到输入张量x。""" + dim = x.shape[-1] + dim_half = dim // 2 + + # 假设我们将embedding的前半部分用于一个方向的RoPE,后半部分用于另一个方向 + # 例如,前一半给X坐标,后一半给Y坐标 (或者反过来,但要保持一致) + x_part_1 = x[..., :dim_half] + x_part_2 = x[..., dim_half:] + + # 将与 abs_positions_x 相关的旋转应用于 x_part_1 + rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x) + # 将与 abs_positions_y 相关的旋转应用于 x_part_2 + rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y) + + # 将它们重新拼接起来。确保顺序与你分割时一致。 + return torch.cat((rotated_part_1, rotated_part_2), dim=-1) + + +def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): + """ + Compute patch coordinates (x, y) + + Args: + grid_hw: (B, 2) tensor representing (H, W) per image + """ + device = grid_hw.device + B = grid_hw.shape[0] + + # Get the number of 
patches per image + H = grid_hw[:, 0] + W = grid_hw[:, 1] + N = H * W + N_total = N.sum() + + # Create the batch index for each patch (B x patch count) + patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) + + # Generate intra-image patch index (row-major order) + patch_id_within_image = torch.arange(N_total, device=device) + patch_id_within_image = ( + patch_id_within_image + - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] + ) + + # Get H/W for each patch according to its image + W_per_patch = W[patch_to_sample] + abs_x = patch_id_within_image % W_per_patch + abs_y = patch_id_within_image // W_per_patch + + return abs_x, abs_y + + +class NeoVisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + hidden_size: int = 1024, + llm_hidden_size: int = 2048, + downsample_ratio: float = 0.5, + patch_size: int = 16, + num_channels: int = 3, + max_position_embeddings_vision: int = 10000, + rope_theta_vision: float = 10000.0, + min_pixels: int = 65536, + max_pixels: int = 2408448, + **kwargs, + ): + super().__init__() + self.weight_dir = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + self.embed_dim = hidden_size + self.llm_hidden_size = llm_hidden_size + self.patch_size = patch_size + self.num_channels = num_channels + self.downsample_ratio = downsample_ratio + self.downsample_factor = int(1 / downsample_ratio) + self.max_position_embeddings_vision = max_position_embeddings_vision + self.rope_theta_vision = rope_theta_vision + self.rope_dim_part = self.embed_dim // 2 + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + self.patch_embedding = nn.Conv2d( + in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size + ) + + self.dense_embedding = nn.Conv2d( + in_channels=self.embed_dim, + out_channels=self.llm_hidden_size, + kernel_size=self.downsample_factor, + stride=self.downsample_factor, + ) + self.gelu = nn.GELU() + + self.repe_dim_part = self.embed_dim // 2 + self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos() + self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos() + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def load_model(self, weight_dir): + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) + self.load_state_dict(weight_dict) + + def precompute_rope_freqs_sincos(self): + inv_freq = 1.0 / ( + self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part) + ) + t = 
torch.arange(self.max_position_embeddings_vision).type_as(inv_freq) + freqs = torch.outer(t, inv_freq) + return torch.cos(freqs), torch.sin(freqs) + + def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): + """ + Apply 2D Rotary Position Embedding to the patch embeddings. + """ + abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) + embeddings = apply_2d_rotary_pos_emb( + patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 + self.cos_x, + self.sin_x, + self.cos_y, + self.sin_y, + abs_pos_x, + abs_pos_y, + ).to(self.patch_embedding.weight.dtype) + return embeddings + + def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + pixel_values = pixel_values.view( + -1, + 3, + self.patch_size, + self.patch_size, + ) + patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) + patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) + assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[ + 0 + ], "Grid size and patch embeds size mismatch." + + patches_list = [] + cur_position = 0 + for i in range(grid_hw.shape[0]): + h, w = grid_hw[i] + patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) + patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) + patches_per_img = patches_per_img.permute(0, 2, 3, 1) + patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) + cur_position += h * w + + embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) + assert cur_position == patch_embeds.shape[0] + assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) + + return embeddings + + def encode(self, images: List[ImageItem]): + img_tensors = [] + valid_ids = [] + valid_id = 0 + img_grids = [] + uuids = [] + + for i, img in enumerate(images): + if isinstance(img, ImageItem): + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + pixel_values, image_grid_hw = load_image_native(image_data) + img_tensors.append(pixel_values) + img_grids.append(image_grid_hw) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + + # must devide merge_length + cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) + print(f"cur_num is {cur_num}") + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + grid_hw = torch.cat(img_grids, dim=0) + + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_hw = grid_hw.to("cuda", non_blocking=True) + + all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw) + + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/neo_chat/vision_process.py b/lightllm/models/neo_chat/vision_process.py new file mode 100644 index 000000000..aa008e18f --- /dev/null +++ b/lightllm/models/neo_chat/vision_process.py @@ -0,0 +1,141 @@ +import re +import math +import torch +import string +import numpy as np +import pandas as pd +from PIL import Image +import torch.distributed as dist +import torchvision.transforms as T + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return 
round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 +def smart_resize( + height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, floor_by_factor(height / beta, factor)) + w_bar = max(factor, floor_by_factor(width / beta, factor)) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def preprocess_pixel_values(pixel_values, patch_size=16): + c, h, w = pixel_values.shape + grid_h = h // patch_size + grid_w = w // patch_size + + flatten_pixel_values = ( + pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) + .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] + .reshape(grid_h * grid_w, c * patch_size ** 2) + ) + + grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) + + return flatten_pixel_values, grid_hw + + +def get_contrasting_background(image): + """ + Calculate the color (white or black) that is different from the average foreground color + to use as the background color + """ + image_np = np.array(image) + if (image_np[:, :, 3] == 0).any(): + non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] + if non_transparent_pixels.size == 0: + return None + pixel_mean = non_transparent_pixels.mean() + contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) + return contrasting_color + else: + return None + + +def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): + """ + Load and preprocess an image file, converting it to RGB mode, + resizing, normalizing, and optionally adding a thumbnail version. 
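A worked example of the resizing arithmetic above, using the defaults in this module (a deployed model's vision_config may use different pixel budgets, patch size, or downsample ratio):

    from lightllm.models.neo_chat.vision_process import smart_resize

    patch_size, downsample_ratio = 16, 0.5               # defaults used throughout this module
    factor = int(patch_size // downsample_ratio)         # 32, the same form of factor the tokenizer computes
    h, w = smart_resize(height=768, width=1024, factor=factor)
    grid_h, grid_w = h // patch_size, w // patch_size    # 48 x 64 patches
    llm_tokens = int(grid_h * grid_w * downsample_ratio ** 2)
    print(h, w, grid_h, grid_w, llm_tokens)              # 768 1024 48 64 768 -> 768 image tokens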
+ """ + if image.mode == "RGBA": + bg_color = get_contrasting_background(image) + if bg_color: + background = Image.new("RGB", image.size, bg_color) + background.paste(image, mask=image.split()[3]) + image = background.convert("RGB") + else: + image = image.convert("RGB") + else: + image = image.convert("RGB") + + if upscale: + image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) + + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), + T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ] + ) + + new_image = dynamic_preprocess_native_resolution( + image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels + ) + pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) + + print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") + + return pixel_values, grid_hw diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e0b2bd425..17f5a741a 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -30,6 +30,7 @@ from ..models.qwen2_vl.model import QWen2VLTokenizer from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer +from ..models.neo_chat.model import NeoChatTokenizer from ..models.gemma3.model import Gemma3Tokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. @@ -104,5 +105,7 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + elif model_type == "neo_chat": + tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) return tokenizer diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d3d1610f3..d77271af8 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -19,6 +19,7 @@ from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel +from lightllm.models.neo_chat.neo_visual import NeoVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry @@ -78,6 +79,8 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() + elif self.model_type == "neo_chat": + self.model = NeoVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() else: raise Exception(f"can not support {self.model_type} now") From fdc1369e315487a59f99c5c52d769bf3328fb43f Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 30 Dec 2025 14:50:41 +0000 Subject: [PATCH 15/71] add-neo-chat --- .../kv_cache_mem_manager/mem_manager.py | 4 +- lightllm/models/llama/model.py | 47 +- lightllm/models/neo_chat/infer_state.py | 95 ++++ .../layer_infer/transformer_layer_infer.py | 166 +++++++ .../layer_weights/transformer_layer_weight.py | 44 ++ lightllm/models/neo_chat/model.py | 17 +- 
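Before the hunks below, a note on why mem_manager.py doubles head_dim in this patch: Neo packs the height/width rotary keys next to the ordinary key, i.e. [k | k_h | k_w] per head, and zero-pads the value to the same width so both fit in one kv buffer. A toy sketch of that layout (sizes are made up; the real packing is done in NeoChatMOETransformerLayerInfer._get_qkv):

    import torch

    seq, n_kv_heads, head_dim = 2, 4, 128                  # illustrative sizes only
    k   = torch.randn(seq, n_kv_heads, head_dim)           # rotated with the temporal rope table
    k_h = torch.randn(seq, n_kv_heads, head_dim // 2)      # rotated with the H (image-row) table
    k_w = torch.randn(seq, n_kv_heads, head_dim // 2)      # rotated with the W (image-column) table
    v   = torch.randn(seq, n_kv_heads, head_dim)

    packed_k = torch.cat([k, k_h, k_w], dim=-1)            # head_dim * 2 per kv head
    packed_v = torch.cat([v, torch.zeros_like(v)], dim=-1)  # padded so k and v share one buffer width
    cache_kv = torch.cat([packed_k, packed_v], dim=1)      # the shape that lands in kv_buffer
    assert cache_kv.shape == (seq, 2 * n_kv_heads, 2 * head_dim)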
.../models/neo_chat/triton_kernel/__init__.py | 0 .../context_attention_fwd_neo.py | 467 ++++++++++++++++++ .../triton_kernel/get_neo_position.py | 174 +++++++ 9 files changed, 1003 insertions(+), 11 deletions(-) create mode 100644 lightllm/models/neo_chat/infer_state.py create mode 100644 lightllm/models/neo_chat/triton_kernel/__init__.py create mode 100644 lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py create mode 100644 lightllm/models/neo_chat/triton_kernel/get_neo_position.py diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index d8fd93009..b599bedfc 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -28,7 +28,7 @@ class MemoryManager: def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num - self.head_dim = head_dim + self.head_dim = head_dim * 2 # neo kv 是[k, k_h, k_w]拼在一起的 self.layer_num = layer_num self.always_copy = always_copy self.dtype = dtype @@ -60,7 +60,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.size, dtype, head_num, - head_dim, + self.head_dim, layer_num, ) self.HOLD_TOKEN_MEMINDEX = self.size diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index a228e0025..36b5d79b5 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -110,6 +110,8 @@ def _init_custom(self): rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: self._init_to_get_rotary() + if "rope_theta_hw" in self.config: + self._init_to_get_hw_rotary() return if "rope_type" in rope_scaling: @@ -132,6 +134,8 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + if "rope_theta_hw" in self.config: + self._init_to_get_hw_rotary() return def _init_weights(self): @@ -178,7 +182,7 @@ def _init_to_get_rotary(self, default_base=10000): rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) base = self.config.get("rope_theta", float(default_base)) - + print(f"base is {base}") if "max_sequence_length" in self.config: max_seq_len = self.config["max_sequence_length"] else: @@ -211,6 +215,47 @@ def _init_to_get_rotary(self, default_base=10000): self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return + def _init_to_get_hw_rotary(self, default_base=10000): + partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2) + if self.config.get("rope_scaling", {}) is None: + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) + + base = self.config.get("rope_theta_hw", float(default_base)) + print(f"hw_base is {base}") + if "max_sequence_length" in self.config: + max_seq_len = self.config["max_sequence_length"] + else: + max_position_embeddings = self.config.get( + "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384 + ) + max_seq_len = max_position_embeddings * rope_scaling_factor + + # NTK + try: + ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula + except: + pass + + inv_freq = 1.0 / ( + base ** (torch.arange(0, 
partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) + ) + t = ( + torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) + / rope_scaling_factor + ) + freqs = torch.outer(t, inv_freq) + + self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda() + return + def _init_to_get_dynamic_ntk_rotary(self): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) max_position_embeddings = self.config.get("max_position_embeddings", 2048) diff --git a/lightllm/models/neo_chat/infer_state.py b/lightllm/models/neo_chat/infer_state.py new file mode 100644 index 000000000..9a71c3ddb --- /dev/null +++ b/lightllm/models/neo_chat/infer_state.py @@ -0,0 +1,95 @@ +from typing import Optional, List +import torch +import numpy as np +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.req_manager import ReqManager +from lightllm.models.neo_chat.triton_kernel.get_neo_position import get_neo_position_triton +from lightllm.models.llama.model import LlamaTpPartModel + + +class NeoChatInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self.position_cos = None + self.position_sin = None + self.position_cos_h = None + self.position_sin_h = None + self.position_cos_w = None + self.position_sin_w = None + + def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor): + LlamaInferStateInfo.init_some_extra_state(self, model, input_ids) + if self.is_prefill: + self.position_ids = self.get_neo_position(self.multimodal_params) + else: + b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] + for batch_idx, p in enumerate(self.multimodal_params): + position_delta = 0 + for image in p["images"]: + position_delta += image["grid_thwd"][3] + b_position_delta[batch_idx] = position_delta + position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) + self.position_ids = position_ids.unsqueeze(0).expand(3, -1) + self.position_ids[1:].zero_() + + self.position_ids = self.position_ids.contiguous() + self.position_cos = model._cos_cached[self.position_ids[0]] + self.position_sin = model._sin_cached[self.position_ids[0]] + self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] + self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] + self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] + self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] + return + + def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: + if len(multimodal_params) == 0: + return self.position_ids.unsqueeze(0).expand(3, -1) + b_image_start_idx = [] + b_image_nums = [] + b_image_start_num = [] + b_image_len = [] + image_start_num = 0 + b_image_thwd = [] + + # pad multimodal_params to batch size. 
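        # Side note (illustration only, not the runtime path): get_neo_position returns a 3-row
        # position tensor. Row 0 starts as the ordinary sequential positions; rows 1 and 2 hold the
        # image row/column indices that get_neo_position_triton fills in for image spans, and
        # init_some_extra_state uses them to index the smaller H/W rope table built by
        # _init_to_get_hw_rotary. At decode time rows 1 and 2 are simply zeroed.
        # A self-contained sketch of that final table lookup, with toy sizes:
        #
        #     import torch
        #     head_dim, max_pos = 8, 16
        #     inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2, 2).float() / (head_dim // 2)))
        #     hw_freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
        #     hw_cos, hw_sin = hw_freqs.cos(), hw_freqs.sin()
        #     position_ids = torch.tensor([[0, 1, 2, 3],    # row 0: sequence position
        #                                  [0, 0, 1, 0],    # row 1: image row, 0 for text tokens
        #                                  [0, 0, 2, 0]])   # row 2: image column, 0 for text tokens
        #     cos_h, cos_w = hw_cos[position_ids[1]], hw_cos[position_ids[2]]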
+ batch_size = self.b_q_seq_len.shape[0] + multimodal_params = multimodal_params + [ + {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params)) + ] + + for _, p in enumerate(multimodal_params): + images = p.get("images", []) + for img in images: + b_image_start_idx.append(img["start_idx"]) + a = img["start_idx"] + print(f"img start_idx: {a}") + b_image_len.append(img["token_num"]) + b_image_thwd.append(img["grid_thwd"]) + b_image_nums.append(len(images)) + b_image_start_num.append(image_start_num) + image_start_num += len(images) + + # 没有任何图片 + if image_start_num == 0: + return self.position_ids.unsqueeze(0).expand(3, -1).contiguous() + b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) + b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 + b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) + b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True) + b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True) + + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + + get_neo_position_triton( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + b_ready_cache_len=self.b_ready_cache_len, + b_q_seq_len=self.b_q_seq_len, + b_start_loc=self.b_start_loc, + ) + return position_ids diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index c9297ee84..b0ee42856 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -1,7 +1,173 @@ +import torch +from functools import partial +from typing import Tuple +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.neo_chat.infer_state import NeoChatInferStateInfo +from lightllm.models.neo_chat.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd +from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.distributed import all_reduce +import torch.distributed as dist +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return + + def _bind_attention(self): + self._context_attention_kernel = self._context_attention_kernel + self._token_attention_kernel = self._token_decode_attention_normal + self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal + return + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: NeoChatInferStateInfo, + layer_weight: NeoChatMOETransformerLayerWeight, + ) -> 
Tuple[torch.Tensor, torch.Tensor]: + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) + + q_h, q_w = layer_weight.q_hw_proj.mm(input).chunk(2, dim=-1) + k_h, k_w = layer_weight.k_hw_proj.mm(input).chunk(2, dim=-1) + + cache_kv = layer_weight.kv_proj.mm(input) + qk_rmsnorm_forward( + q, + weight=layer_weight.q_norm_weight_.weight, + eps=self.eps_, + ) + + qk_rmsnorm_forward( + q_h, + weight=layer_weight.q_norm_h_weight_.weight, + eps=self.eps_, + ) + + qk_rmsnorm_forward( + q_w, + weight=layer_weight.q_norm_w_weight_.weight, + eps=self.eps_, + ) + + qk_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + weight=layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ) + + qk_rmsnorm_forward( + k_h, + weight=layer_weight.k_norm_h_weight_.weight, + eps=self.eps_, + ) + qk_rmsnorm_forward( + k_w, + weight=layer_weight.k_norm_w_weight_.weight, + eps=self.eps_, + ) + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h.view(-1, self.tp_q_head_num_, self.head_dim_ // 2), + k_h.view(-1, self.tp_k_head_num_, self.head_dim_ // 2), + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w.view(-1, self.tp_q_head_num_, self.head_dim_ // 2), + k_w.view(-1, self.tp_k_head_num_, self.head_dim_ // 2), + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + # 拼接q q_h q_w + q = torch.cat([q, q_h, q_w], dim=-1) + # 拼接k k_h k_w + seq_len = cache_kv.shape[0] + k_h = k_h.reshape(seq_len, self.tp_k_head_num_, self.head_dim_ // 2) + k_w = k_w.reshape(seq_len, self.tp_k_head_num_, self.head_dim_ // 2) + k = cache_kv[:, : self.tp_k_head_num_, :] + k = torch.cat([k, k_h, k_w], dim=-1) + # 对齐V的shape + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros( + (seq_len, self.tp_v_head_num_, self.head_dim_), + device=v.device, + dtype=v.dtype, + ) + v = torch.cat([v, v_pad], dim=-1) + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + infer_state.position_ids[0], + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o3 = o3[:, :, : self.head_dim_].contiguous() + return o3.view(o3.shape[0], -1) + + def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + + token_att_fwd( + q.view(calcu_shape1), + 
infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + att_m_tensor, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o_tensor = o_tensor[:, :, : self.head_dim_].contiguous() + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + calcu_shape2 = (batch_size, self.tp_q_head_num_, self.head_dim_) + token_softmax_reducev_fwd( + att_m_tensor, + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor.view(calcu_shape2), + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o_tensor diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py index 2dc87f3ca..bc38f1adc 100644 --- a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -1,7 +1,51 @@ from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + NormWeight, + ROWMMWeight, +) class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, data_type, network_config, mode, quant_cfg) return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" + self._q_bias_hw_name = None + self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" + self._k_bias_hw_name = None + + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + + def _init_qkv(self): + super()._init_qkv() + self.q_hw_proj = ROWMMWeight( + weight_names=self._q_weight_hw_name, + data_type=self.data_type_, + bias_names=self._q_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="q_hw_proj", + ) + self.k_hw_proj = ROWMMWeight( + weight_names=self._k_weight_hw_name, + data_type=self.data_type_, + bias_names=self._k_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="k_hw_proj", + ) + + def _init_norm(self): + super()._init_norm() + + self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) + self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) + self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) + self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py index 61fd98b98..edc734986 100644 --- a/lightllm/models/neo_chat/model.py +++ b/lightllm/models/neo_chat/model.py @@ 
-19,6 +19,7 @@ from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.models.neo_chat.infer_state import NeoChatInferStateInfo IMG_START_TOKEN = "" IMG_END_TOKEN = "" @@ -65,10 +66,10 @@ def get_image_token_length(self, img: ImageItem): max_pixels=self.max_pixel, ) grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size - token_num = (grid_h * grid_w) * (self.downsample_ratio ** 2) - # grid_thwd是为了mrope准备的,这里不需要 - img.grid_thwd = (1, grid_h, grid_w, 0) - return int(token_num) + token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2)) + # 这里的grid_h和grid_w需要* self.downsample_ratio么?再仔细看下代码 + img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num) + return token_num # only change the impl of the encode func: def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): @@ -87,23 +88,23 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): start_idx = 0 while True: try: - start_idx = origin_ids.index(self.image_start_id, start_idx) + start_idx = origin_ids.index(self.image_start_id) if start_idx + 1 >= len(origin_ids): break if origin_ids[start_idx + 1] == self.image_end_id: input_ids.extend(origin_ids[: start_idx + 1]) token_id = multimodal_params.images[image_id].token_id token_num = multimodal_params.images[image_id].token_num + multimodal_params.images[image_id].start_idx = len(input_ids) input_ids.extend(range(token_id, token_id + token_num)) input_ids.append(self.image_end_id) origin_ids = origin_ids[start_idx + 2 :] - start_idx = 0 image_id += 1 else: raise ValueError("image token error") except ValueError: break - input_ids.extend(origin_ids[start_idx:]) + input_ids.extend(origin_ids) return input_ids @@ -116,7 +117,7 @@ class NeoTpPartModel(Qwen3MOEModel): pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight transformer_weight_class = NeoChatMOETransformerLayerWeight - infer_state_class = LlamaInferStateInfo + infer_state_class = NeoChatInferStateInfo def __init__(self, kvargs): super().__init__(kvargs) diff --git a/lightllm/models/neo_chat/triton_kernel/__init__.py b/lightllm/models/neo_chat/triton_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py new file mode 100644 index 000000000..46376502f --- /dev/null +++ b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py @@ -0,0 +1,467 @@ +# context_attention_fwd_neo_pos1d.py +# From : https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html + +import math +import torch +import triton +import triton.language as tl + +from lightllm.utils.device_utils import is_tesla + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + Out, + position_ids, # 1D, concatenated by batch order, length = sum(B_Seqlen) + B_Pos_Start, # [batch], prefix sum of B_Seqlen (int32) + B_Start_Loc, + B_Seqlen, + Req_to_tokens, + B_req_idx, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + 
kv_group_num, + b_prompt_cache_len, + H: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + cur_bh = tl.program_id(1) + cur_batch = cur_bh // H + cur_head = cur_bh % H + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + total_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = total_len - prompt_cache_len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + # where this request starts inside 1D position_ids + pos_base = tl.load(B_Pos_Start + cur_batch) + + block_start_loc = BLOCK_M * start_m + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = block_start_loc + tl.arange(0, BLOCK_M) + + # load Q for current block + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q_valid = offs_m < cur_batch_seq_len + q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) + + # init online softmax + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = total_len + + # query absolute pos inside request: [prompt_cache_len .. total_len-1] + q_pos = prompt_cache_len + offs_m + + # gid by pos (NOT by mem_index) + q_gid = tl.load( + position_ids + pos_base + q_pos, + mask=q_valid, + other=-2147483648, + ).to(tl.int32) + + # main loop over keys by logical pos + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_pos = start_n + offs_n + k_valid = k_pos < block_end_loc + + # gid by pos (NOT by mem_index) + k_gid = tl.load( + position_ids + pos_base + k_pos, + mask=k_valid, + other=-2147483647, + ).to(tl.int32) + + # map logical k_pos -> kv cache mem_index + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_valid, + other=0, + ).to(tl.int64) + + # load K using mem_index + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) + + # qk + qk = tl.dot(q, k) + + # mask: causal OR same gid (image block full-attn) + mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) + mask = mask & q_valid[:, None] & k_valid[None, :] + + qk = tl.where(mask, qk * sm_scale, -1.0e8) + + # online softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # load V using mem_index + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) + + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + + # store + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=q_valid[:, None]) + + +@torch.no_grad() +def context_attention_fwd_neo( + q, + k, + v, + o, + position_ids, # 1D concatenated for this batch + b_req_idx, + b_start_loc, + b_seq_len, + 
b_prompt_cache_len, + max_input_len, + req_to_token_indexs, +): + # position_ids must cover sum of b_seq_len + # b_pos_start: prefix sum over b_seq_len, defines each request's start inside position_ids + # NOTE: assumes position_ids is concatenated in the SAME order as cur_batch = 0..batch-1 + batch = b_seq_len.shape[0] + device = b_seq_len.device + b_pos_start = torch.zeros((batch,), device=device, dtype=torch.int32) + if batch > 1: + b_pos_start[1:] = torch.cumsum(b_seq_len[:-1].to(torch.int32), dim=0) + + needed = int((b_pos_start[-1] + b_seq_len[-1]).item()) + assert position_ids.numel() >= needed, (position_ids.numel(), needed) + + BLOCK_M = 128 if not is_tesla() else 64 + + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128, 256} + + # same trick as your original: exp2 + 1/log(2) + sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 + + head = q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + + BLOCK_N = BLOCK_M + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + position_ids, + b_pos_start, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + H=head, + BLOCK_DMODEL=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def reference_attention( + q, + k, + v, + position_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, +): + """ + q: [sum_q, Hq, D] packed by b_start_loc + k/v: [KV_SIZE, Hk, D] by mem_index + position_ids: 1D concatenated by batch order, length = sum(b_seq_len) + """ + device = q.device + dtype = q.dtype + sum_q, Hq, D = q.shape + Hk = k.shape[1] + kv_group_num = Hq // Hk + + batch = b_seq_len.shape[0] + b_pos_start = torch.zeros((batch,), device=device, dtype=torch.int64) + if batch > 1: + b_pos_start[1:] = torch.cumsum(b_seq_len[:-1].to(torch.int64), dim=0) + + out = torch.empty_like(q) + + scale = 1.0 / math.sqrt(D) + + for b in range(batch): + req = int(b_req_idx[b].item()) + total_len = int(b_seq_len[b].item()) + prompt_len = int(b_prompt_cache_len[b].item()) + q_len = total_len - prompt_len + + q_start = int(b_start_loc[b].item()) + q_blk = q[q_start : q_start + q_len] # [M, Hq, D] + + pos_base = int(b_pos_start[b].item()) + gid = position_ids[pos_base : pos_base + total_len].to(torch.int64) # [L] + + # gather K/V for this request by logical pos -> mem_index + token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] + k_blk = k[token_locs] # [L, Hk, D] + v_blk = v[token_locs] # [L, Hk, D] + + # expand kv heads to q heads (GQA) + k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + + # build mask by pos + q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] + k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] + allow = (k_pos[None, :] <= q_pos[:, None]) | (gid[q_pos][:, None] == gid[k_pos][None, :]) # [M, L] + + # scores: [Hq, M, L] + q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] + k_t = 
k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] + scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] + + # mask + neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) + scores = torch.where(allow[None, :, :], scores, neg) + + # softmax + reduce + p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] + v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] + out_hq = torch.matmul(p, v_t) # [Hq, M, D] + out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] + + out[q_start : q_start + q_len] = out_blk + + return out + + +def make_test_case( + device="cuda", + dtype=torch.float16, + batch=3, + Hq=8, + Hk=4, + D=64, + seed=0, + base_index=5000, +): + torch.manual_seed(seed) + + prompt_lens = torch.randint(low=1, high=5, size=(batch,), device=device) + q_lens = torch.randint(low=2, high=8, size=(batch,), device=device) + total_lens = (prompt_lens + q_lens).to(torch.int32) + + max_total_len = int(total_lens.max().item()) + + # b_start_loc for packed q (q_len per batch) + b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) + cur = 0 + for b in range(batch): + b_start_loc[b] = cur + cur += int(q_lens[b].item()) + sum_q = cur + + b_seq_len = total_lens + b_prompt_cache_len = prompt_lens.to(torch.int32) + + # one req per batch for test + num_req = batch + b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) + + # build a global KV "mem_index" space with offset, to simulate large indices + sum_kv = int(total_lens.sum().item()) + kv_size = base_index + sum_kv + 16 + + # allocate unique mem indices + pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index + + # Req_to_tokens: [num_req, max_total_len] + req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) + p = 0 + for r in range(num_req): + L = int(total_lens[r].item()) + req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) + p += L + + # position_ids: 1D concatenated by batch order (length = sum_kv) + position_ids = torch.empty((sum_kv,), device=device, dtype=torch.int32) + off = 0 + for r in range(num_req): + L = int(total_lens[r].item()) + gid = torch.arange(L, device=device, dtype=torch.int32) + + # make 0-2 repeated blocks (simulate image tokens) + if L >= 4: + # repeat a short block + s = int(torch.randint(0, max(1, L - 2), (1,), device=device).item()) + e = min(L, s + int(torch.randint(2, min(4, L - s) + 1, (1,), device=device).item())) + gid[s:e] = gid[s] + if L >= 8 and torch.rand((), device=device).item() > 0.5: + s = 4 + e = min(L, 7) + gid[s:e] = gid[s] + + position_ids[off : off + L] = gid + off += L + + q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) + k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) + + return ( + q, + k, + v, + o, + position_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_total_len, + req_to_token_indexs, + ) + + +def check_once(device="cuda", dtype=torch.float16, seed=0): + ( + q, + k, + v, + o, + position_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_total_len, + req_to_token_indexs, + ) = make_test_case(device=device, dtype=dtype, seed=seed) + + # triton + context_attention_fwd_neo( + q, + k, + v, + o, + position_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_total_len, + req_to_token_indexs, + ) + + # reference + ref = 
reference_attention( + q, + k, + v, + position_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + ) + + diff = (o - ref).abs() + max_abs = diff.max().item() + denom = ref.abs().max().item() + 1e-6 + max_rel = max_abs / denom + + print(f"seed={seed}, dtype={dtype}") + print(f"max_abs_error = {max_abs:.6e}") + print(f"max_rel_error = {max_rel:.6e}") + print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) + + +# if __name__ == "__main__": +# if not torch.cuda.is_available(): +# print("No CUDA, skip Triton check.") +# else: +# torch.cuda.synchronize() +# check_once(dtype=torch.float16, seed=0) +# check_once(dtype=torch.float16, seed=1) +# check_once(dtype=torch.float16, seed=2) diff --git a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py new file mode 100644 index 000000000..5cf270a12 --- /dev/null +++ b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py @@ -0,0 +1,174 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_thwd_stride0: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: torch.Tensor, + position_ids: torch.Tensor, + position_ids_stride0: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, + BLOCK_SIZE: tl.constexpr, +) -> torch.Tensor: + cur_batch = tl.program_id(0) + cache_len = tl.load(b_ready_cache_len + cur_batch) + q_seq_len = tl.load(b_q_seq_len + cur_batch) + image_num = tl.load(b_image_nums + cur_batch) + image_start_num = tl.load(b_image_start_num + cur_batch) + start_loc = tl.load(b_start_loc + cur_batch) + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_start_idx = start_loc + local_image_start_idx - cache_len + image_len = tl.load(b_image_len + image_start_num + i) + image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) + image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) + for j in range(0, image_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + # 目前没考虑视频,所以t 恒为 0 + t_pos = local_image_start_idx + off * 0 + h_pos = off // image_h + w_pos = off % image_w + tl.store( + position_ids + off + image_start_idx, + t_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 + off + image_start_idx, + h_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 * 2 + off + image_start_idx, + w_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_len = tl.load(b_image_len + image_start_num + i) + image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3) + image_end = local_image_start_idx + image_len - cache_len + text_start = tl.maximum(0, image_end) + for j in range(text_start, q_seq_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + t_pos = tl.load(position_ids + off + start_loc, 
mask=(off < q_seq_len), other=0.0) + image_delta + h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0) + w_pos = tl.load( + position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0 + ) + tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len)) + return + + +def get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: torch.Tensor, + position_ids: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, +) -> torch.Tensor: + + batch_size = b_q_seq_len.shape[0] + assert batch_size == b_image_nums.shape[0] + grid = (batch_size,) + BLOCK_SIZE = 64 + _get_neo_position_triton[grid]( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_thwd_stride0=b_image_thwd.stride(0), + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + position_ids_stride0=position_ids.stride(0), + b_ready_cache_len=b_ready_cache_len, + b_q_seq_len=b_q_seq_len, + b_start_loc=b_start_loc, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +def test(): + b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") + b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") + b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") + position_ids = ( + torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + .unsqueeze(0) + .expand(3, -1) + .contiguous() + ) + position_ids[1:].zero_() + b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") + b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") + b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") + get_neo_position_triton( + b_image_start_idx, + b_image_thwd, + b_image_nums, + b_image_start_num, + b_image_len, + position_ids, + b_ready_cache_len, + b_q_seq_len, + b_start_loc, + ) + + print(position_ids) + # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) + + # position_ids = ( + # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + # .unsqueeze(0) + # .expand(3, -1) + # .contiguous() + # ) + # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") + # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") + # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") + + # get_neo_position_triton( + # b_image_start_idx, + # b_image_thwd, + # b_image_nums, + # b_image_start_num, + # b_image_len, + # position_ids, + # b_ready_cache_len, + # b_q_seq_len, + # b_start_loc, + # ) + + # print(f"old_value:\n{old_value}") + # print(f"position_ids:\n{position_ids}") + # assert torch.equal(old_value, position_ids) + + """ + tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], + [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8], + [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 
2, 3, 2, 3, 4, 5, 6, 7, 8]], + device='cuda:0', dtype=torch.int32) + """ From e8e74168c24b2c34a117561c5bec245e930aaa52 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 31 Dec 2025 05:12:28 +0000 Subject: [PATCH 16/71] add-neo-chat --- .../neo_chat/{infer_state.py => infer_struct.py} | 10 +++++++--- .../neo_chat/layer_infer/transformer_layer_infer.py | 2 +- lightllm/models/neo_chat/model.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) rename lightllm/models/neo_chat/{infer_state.py => infer_struct.py} (91%) diff --git a/lightllm/models/neo_chat/infer_state.py b/lightllm/models/neo_chat/infer_struct.py similarity index 91% rename from lightllm/models/neo_chat/infer_state.py rename to lightllm/models/neo_chat/infer_struct.py index 9a71c3ddb..8e5347e8b 100644 --- a/lightllm/models/neo_chat/infer_state.py +++ b/lightllm/models/neo_chat/infer_struct.py @@ -29,7 +29,7 @@ def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor position_delta += image["grid_thwd"][3] b_position_delta[batch_idx] = position_delta position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) - self.position_ids = position_ids.unsqueeze(0).expand(3, -1) + self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone() self.position_ids[1:].zero_() self.position_ids = self.position_ids.contiguous() @@ -43,7 +43,9 @@ def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: if len(multimodal_params) == 0: - return self.position_ids.unsqueeze(0).expand(3, -1) + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids b_image_start_idx = [] b_image_nums = [] b_image_start_num = [] @@ -71,7 +73,9 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: # 没有任何图片 if image_start_num == 0: - return self.position_ids.unsqueeze(0).expand(3, -1).contiguous() + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids.contiguous() b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index b0ee42856..e6b0402bb 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -3,7 +3,7 @@ from typing import Tuple from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.neo_chat.infer_state import NeoChatInferStateInfo +from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo from lightllm.models.neo_chat.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py index edc734986..0cc469cea 100644 --- a/lightllm/models/neo_chat/model.py +++ 
b/lightllm/models/neo_chat/model.py @@ -19,7 +19,7 @@ from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer -from lightllm.models.neo_chat.infer_state import NeoChatInferStateInfo +from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo IMG_START_TOKEN = "" IMG_END_TOKEN = "" From ba4498317f427f316c15da8a6301e72422d94136 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 31 Dec 2025 05:28:39 +0000 Subject: [PATCH 17/71] add-neo-chat --- .../context_attention_fwd_neo.py | 217 ++++++++---------- 1 file changed, 101 insertions(+), 116 deletions(-) diff --git a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py index 46376502f..80fc2ea44 100644 --- a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py @@ -1,6 +1,3 @@ -# context_attention_fwd_neo_pos1d.py -# From : https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html - import math import torch import triton @@ -16,8 +13,7 @@ def _fwd_kernel( V, sm_scale, Out, - position_ids, # 1D, concatenated by batch order, length = sum(B_Seqlen) - B_Pos_Start, # [batch], prefix sum of B_Seqlen (int32) + position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] B_Start_Loc, B_Seqlen, Req_to_tokens, @@ -53,28 +49,26 @@ def _fwd_kernel( cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) total_len = tl.load(B_Seqlen + cur_batch) - cur_batch_seq_len = total_len - prompt_cache_len + cur_batch_seq_len = total_len - prompt_cache_len # NEW len cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - # where this request starts inside 1D position_ids - pos_base = tl.load(B_Pos_Start + cur_batch) - block_start_loc = BLOCK_M * start_m offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = block_start_loc + tl.arange(0, BLOCK_M) - # load Q for current block + # Q pointers off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd ) + q_valid = offs_m < cur_batch_seq_len q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) - # init online softmax + # online softmax state m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -82,44 +76,55 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = total_len - # query absolute pos inside request: [prompt_cache_len .. 
total_len-1] - q_pos = prompt_cache_len + offs_m + # absolute q positions in the request + q_pos = prompt_cache_len + offs_m # [M] - # gid by pos (NOT by mem_index) + # q_gid from packed position_ids (aligned with Q rows) q_gid = tl.load( - position_ids + pos_base + q_pos, + position_ids + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=-2147483648, ).to(tl.int32) - # main loop over keys by logical pos + BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - k_pos = start_n + offs_n - k_valid = k_pos < block_end_loc - # gid by pos (NOT by mem_index) - k_gid = tl.load( - position_ids + pos_base + k_pos, - mask=k_valid, - other=-2147483647, - ).to(tl.int32) + k_pos = start_n + offs_n # [N] + k_valid = k_pos < block_end_loc - # map logical k_pos -> kv cache mem_index + # map logical pos -> mem_index (for K/V) kv_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, mask=k_valid, other=0, ).to(tl.int64) - # load K using mem_index + # load K off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) - # qk qk = tl.dot(q, k) - # mask: causal OR same gid (image block full-attn) + # k_gid: + # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false + # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len) + k_in_new = k_pos >= prompt_cache_len + k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new + k_gid_new = tl.load( + position_ids + cur_batch_in_all_start_index + k_new_idx, + mask=k_valid & k_in_new, + other=-2147483647, + ).to(tl.int32) + + k_gid = tl.where( + k_in_new, + k_gid_new, + (k_pos.to(tl.int32) + BIG), + ) + + # mask: causal OR same gid (only possible inside NEW part) mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) mask = mask & q_valid[:, None] & k_valid[None, :] @@ -127,7 +132,7 @@ def _fwd_kernel( # online softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] + qk -= m_ij[:, None] p = tl.math.exp2(qk) l_ij = tl.sum(p, 1) @@ -135,7 +140,7 @@ def _fwd_kernel( l_i = l_i * alpha + l_ij acc = acc * alpha[:, None] - # load V using mem_index + # load V off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) @@ -146,14 +151,12 @@ def _fwd_kernel( acc = acc / l_i[:, None] - # store off_o = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=q_valid[:, None]) + tl.store(Out + off_o, acc, mask=q_valid[:, None]) @torch.no_grad() @@ -162,7 +165,7 @@ def context_attention_fwd_neo( k, v, o, - position_ids, # 1D concatenated for this batch + position_ids, # 1D packed like q (only NEW tokens) b_req_idx, b_start_loc, b_seq_len, @@ -170,17 +173,8 @@ def context_attention_fwd_neo( max_input_len, req_to_token_indexs, ): - # position_ids must cover sum of b_seq_len - # b_pos_start: prefix sum over b_seq_len, defines each request's start inside position_ids - # NOTE: assumes position_ids is concatenated in the SAME order as cur_batch = 0..batch-1 - batch = b_seq_len.shape[0] - device = b_seq_len.device - b_pos_start = torch.zeros((batch,), 
device=device, dtype=torch.int32) - if batch > 1: - b_pos_start[1:] = torch.cumsum(b_seq_len[:-1].to(torch.int32), dim=0) - - needed = int((b_pos_start[-1] + b_seq_len[-1]).item()) - assert position_ids.numel() >= needed, (position_ids.numel(), needed) + # minimal safety: position_ids must cover packed q rows + assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) BLOCK_M = 128 if not is_tesla() else 64 @@ -188,10 +182,9 @@ def context_attention_fwd_neo( assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128, 256} - # same trick as your original: exp2 + 1/log(2) sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 - head = q.shape[1] + batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) @@ -207,7 +200,6 @@ def context_attention_fwd_neo( sm_scale, o, position_ids, - b_pos_start, b_start_loc, b_seq_len, req_to_token_indexs, @@ -241,18 +233,13 @@ def reference_attention( q, k, v, - position_ids, + position_ids_q, # 1D packed like q (only NEW tokens) b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, req_to_token_indexs, ): - """ - q: [sum_q, Hq, D] packed by b_start_loc - k/v: [KV_SIZE, Hk, D] by mem_index - position_ids: 1D concatenated by batch order, length = sum(b_seq_len) - """ device = q.device dtype = q.dtype sum_q, Hq, D = q.shape @@ -260,27 +247,20 @@ def reference_attention( kv_group_num = Hq // Hk batch = b_seq_len.shape[0] - b_pos_start = torch.zeros((batch,), device=device, dtype=torch.int64) - if batch > 1: - b_pos_start[1:] = torch.cumsum(b_seq_len[:-1].to(torch.int64), dim=0) - out = torch.empty_like(q) - scale = 1.0 / math.sqrt(D) for b in range(batch): req = int(b_req_idx[b].item()) total_len = int(b_seq_len[b].item()) prompt_len = int(b_prompt_cache_len[b].item()) - q_len = total_len - prompt_len + new_len = total_len - prompt_len q_start = int(b_start_loc[b].item()) - q_blk = q[q_start : q_start + q_len] # [M, Hq, D] - - pos_base = int(b_pos_start[b].item()) - gid = position_ids[pos_base : pos_base + total_len].to(torch.int64) # [L] + q_blk = q[q_start : q_start + new_len] # [M, Hq, D] + gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] - # gather K/V for this request by logical pos -> mem_index + # gather K/V for full request by logical pos -> mem_index token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] k_blk = k[token_locs] # [L, Hk, D] v_blk = v[token_locs] # [L, Hk, D] @@ -289,27 +269,39 @@ def reference_attention( k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] - # build mask by pos + # positions q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] - allow = (k_pos[None, :] <= q_pos[:, None]) | (gid[q_pos][:, None] == gid[k_pos][None, :]) # [M, L] + + # build allow mask: + # causal always + allow = k_pos[None, :] <= q_pos[:, None] + + # full-attn only inside NEW part by gid + # compare only when k_pos in NEW + k_in_new = k_pos >= prompt_len + k_rel = (k_pos - prompt_len).clamp_min(0) # [L] + # map k_rel to gid_new, but only valid where k_in_new + k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) + k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new + k_gid[k_in_new] = gid_new[k_rel[k_in_new]] + + allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) # 
scores: [Hq, M, L] q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] - # mask neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) scores = torch.where(allow[None, :, :], scores, neg) - # softmax + reduce p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] out_hq = torch.matmul(p, v_t) # [Hq, M, D] out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] - out[q_start : q_start + q_len] = out_blk + out[q_start : q_start + new_len] = out_blk return out @@ -322,39 +314,39 @@ def make_test_case( Hk=4, D=64, seed=0, - base_index=5000, + base_index=50000, ): torch.manual_seed(seed) - prompt_lens = torch.randint(low=1, high=5, size=(batch,), device=device) - q_lens = torch.randint(low=2, high=8, size=(batch,), device=device) - total_lens = (prompt_lens + q_lens).to(torch.int32) + # prompt (cached) len and new len + prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) + new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) + total_lens = (prompt_lens + new_lens).to(torch.int32) max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) - # b_start_loc for packed q (q_len per batch) + # packed q start b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) cur = 0 for b in range(batch): b_start_loc[b] = cur - cur += int(q_lens[b].item()) + cur += int(new_lens[b].item()) sum_q = cur b_seq_len = total_lens b_prompt_cache_len = prompt_lens.to(torch.int32) - # one req per batch for test + # one req per batch num_req = batch b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) - # build a global KV "mem_index" space with offset, to simulate large indices + # global KV space large, indices not small sum_kv = int(total_lens.sum().item()) - kv_size = base_index + sum_kv + 16 - - # allocate unique mem indices + kv_size = base_index + sum_kv + 1024 pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index - # Req_to_tokens: [num_req, max_total_len] + # Req_to_tokens [num_req, max_total_len] req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) p = 0 for r in range(num_req): @@ -362,26 +354,21 @@ def make_test_case( req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) p += L - # position_ids: 1D concatenated by batch order (length = sum_kv) - position_ids = torch.empty((sum_kv,), device=device, dtype=torch.int32) - off = 0 - for r in range(num_req): - L = int(total_lens[r].item()) - gid = torch.arange(L, device=device, dtype=torch.int32) + # position_ids_q: only NEW tokens, packed like q + position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) + for b in range(batch): + M = int(new_lens[b].item()) + start = int(b_start_loc[b].item()) - # make 0-2 repeated blocks (simulate image tokens) - if L >= 4: - # repeat a short block - s = int(torch.randint(0, max(1, L - 2), (1,), device=device).item()) - e = min(L, s + int(torch.randint(2, min(4, L - s) + 1, (1,), device=device).item())) - gid[s:e] = gid[s] - if L >= 8 and torch.rand((), device=device).item() > 0.5: - s = 4 - e = min(L, 7) + gid = torch.arange(M, device=device, dtype=torch.int32) + + # make one repeated block inside NEW part to simulate image tokens + if M >= 4 and torch.rand((), device=device).item() > 0.3: + s = int(torch.randint(0, M - 2, (1,), 
device=device).item()) + e = min(M, s + 3) gid[s:e] = gid[s] - position_ids[off : off + L] = gid - off += L + position_ids_q[start : start + M] = gid q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) @@ -393,12 +380,12 @@ def make_test_case( k, v, o, - position_ids, + position_ids_q, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, - max_total_len, + max_new_len, req_to_token_indexs, ) @@ -409,36 +396,34 @@ def check_once(device="cuda", dtype=torch.float16, seed=0): k, v, o, - position_ids, + position_ids_q, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, - max_total_len, + max_new_len, req_to_token_indexs, ) = make_test_case(device=device, dtype=dtype, seed=seed) - # triton context_attention_fwd_neo( q, k, v, o, - position_ids, + position_ids_q, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, - max_total_len, + max_new_len, req_to_token_indexs, ) - # reference ref = reference_attention( q, k, v, - position_ids, + position_ids_q, b_req_idx, b_start_loc, b_seq_len, @@ -457,11 +442,11 @@ def check_once(device="cuda", dtype=torch.float16, seed=0): print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) -# if __name__ == "__main__": -# if not torch.cuda.is_available(): -# print("No CUDA, skip Triton check.") -# else: -# torch.cuda.synchronize() -# check_once(dtype=torch.float16, seed=0) -# check_once(dtype=torch.float16, seed=1) -# check_once(dtype=torch.float16, seed=2) +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + torch.cuda.synchronize() + check_once(dtype=torch.bfloat16, seed=0) + check_once(dtype=torch.bfloat16, seed=1) + check_once(dtype=torch.bfloat16, seed=2) From 4d41a33fcd6d3447661486b6c79c12f268410a5f Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 31 Dec 2025 10:27:23 +0000 Subject: [PATCH 18/71] add-neo-chat --- .../token_attention_nopad_att1.py | 3 +- .../layer_infer/transformer_layer_infer.py | 120 ++++++++---------- .../context_attention_fwd_neo.py | 4 +- 3 files changed, 57 insertions(+), 70 deletions(-) diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py index eb5af6fec..02bd277ad 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @@ -74,7 +74,8 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk assert Lk in {16, 32, 64, 128, 256} - sm_scale = 1.0 / (Lk ** 0.5) + Lk_scale = Lk // 2 + sm_scale = 1.0 / (Lk_scale ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index e6b0402bb..b0105131f 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -26,36 +26,28 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal return - def _get_qkv( - self, - input: torch.Tensor, - infer_state: NeoChatInferStateInfo, - layer_weight: NeoChatMOETransformerLayerWeight, - ) -> Tuple[torch.Tensor, torch.Tensor]: + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): input = input.view(-1, self.embed_dim_) - q = 
layer_weight.q_proj.mm(input) + q = layer_weight.q_proj.mm(input) # [T, Hq*D] - q_h, q_w = layer_weight.q_hw_proj.mm(input).chunk(2, dim=-1) - k_h, k_w = layer_weight.k_hw_proj.mm(input).chunk(2, dim=-1) + q_hw = layer_weight.q_hw_proj.mm(input) + q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) - cache_kv = layer_weight.kv_proj.mm(input) - qk_rmsnorm_forward( - q, - weight=layer_weight.q_norm_weight_.weight, - eps=self.eps_, - ) + k_hw = layer_weight.k_hw_proj.mm(input) + k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) - qk_rmsnorm_forward( - q_h, - weight=layer_weight.q_norm_h_weight_.weight, - eps=self.eps_, - ) + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - qk_rmsnorm_forward( - q_w, - weight=layer_weight.q_norm_w_weight_.weight, - eps=self.eps_, - ) + qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + + q_h_2d = q_h.reshape(q.shape[0], -1) + q_w_2d = q_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) qk_rmsnorm_forward( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -63,17 +55,15 @@ def _get_qkv( eps=self.eps_, ) - qk_rmsnorm_forward( - k_h, - weight=layer_weight.k_norm_h_weight_.weight, - eps=self.eps_, - ) - qk_rmsnorm_forward( - k_w, - weight=layer_weight.k_norm_w_weight_.weight, - eps=self.eps_, - ) + k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] + k_w_2d = k_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, : self.tp_k_head_num_, :], @@ -81,33 +71,29 @@ def _get_qkv( infer_state.position_sin, ) rotary_emb_fwd( - q_h.view(-1, self.tp_q_head_num_, self.head_dim_ // 2), - k_h.view(-1, self.tp_k_head_num_, self.head_dim_ // 2), + q_h, + k_h, infer_state.position_cos_h, infer_state.position_sin_h, ) rotary_emb_fwd( - q_w.view(-1, self.tp_q_head_num_, self.head_dim_ // 2), - k_w.view(-1, self.tp_k_head_num_, self.head_dim_ // 2), + q_w, + k_w, infer_state.position_cos_w, infer_state.position_sin_w, ) - # 拼接q q_h q_w - q = torch.cat([q, q_h, q_w], dim=-1) - # 拼接k k_h k_w - seq_len = cache_kv.shape[0] - k_h = k_h.reshape(seq_len, self.tp_k_head_num_, self.head_dim_ // 2) - k_w = k_w.reshape(seq_len, self.tp_k_head_num_, self.head_dim_ // 2) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + k = cache_kv[:, : self.tp_k_head_num_, :] k = torch.cat([k, k_h, k_w], dim=-1) - # 对齐V的shape + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] - v_pad = torch.zeros( - (seq_len, self.tp_v_head_num_, self.head_dim_), - device=v.device, - dtype=v.dtype, - ) + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) v = torch.cat([v, v_pad], 
dim=-1) + cache_kv = torch.cat([k, v], dim=1) return q, cache_kv @@ -121,7 +107,7 @@ def _context_attention_kernel( kv[:, 0 : self.tp_k_head_num_, :], kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), - infer_state.position_ids[0], + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, @@ -136,13 +122,15 @@ def _context_attention_kernel( def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] token_att_fwd( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + q_3d, + k_3d, att_m_tensor, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -150,24 +138,22 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, infer_state.b_seq_len, infer_state.max_len_in_batch, ) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) - o_tensor = o_tensor[:, :, : self.head_dim_].contiguous() - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + + v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + ] + + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out - calcu_shape2 = (batch_size, self.tp_q_head_num_, self.head_dim_) token_softmax_reducev_fwd( att_m_tensor, - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape2), + v_3d, + o_3d, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, ) - return o_tensor + return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py index 80fc2ea44..f5dae493c 100644 --- a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py @@ -181,8 +181,8 @@ def context_attention_fwd_neo( Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128, 256} - - sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 + base_head_dim = Lq // 2 + sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634 batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] From 0e8845c160ce2563815eb0dfc5d851d209b5dbd7 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 1 Jan 2026 16:38:33 +0000 Subject: [PATCH 19/71] fix-neo-chat --- lightllm/models/neo_chat/neo_visual.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git 
a/lightllm/models/neo_chat/neo_visual.py b/lightllm/models/neo_chat/neo_visual.py index c9d4b8161..16b30511e 100644 --- a/lightllm/models/neo_chat/neo_visual.py +++ b/lightllm/models/neo_chat/neo_visual.py @@ -247,7 +247,13 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - pixel_values, image_grid_hw = load_image_native(image_data) + pixel_values, image_grid_hw = load_image_native( + image_data, + patch_size=self.patch_size, + downsample_ratio=self.downsample_ratio, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) img_tensors.append(pixel_values) img_grids.append(image_grid_hw) else: From b48cd499e1b661e55ca8f32c5f6f0164e1da7045 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 5 Jan 2026 05:00:10 +0000 Subject: [PATCH 20/71] fix-neo-chat-position-ids-h --- lightllm/models/neo_chat/triton_kernel/get_neo_position.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py index 5cf270a12..955f48bd8 100644 --- a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py +++ b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py @@ -28,13 +28,13 @@ def _get_neo_position_triton( local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) image_start_idx = start_loc + local_image_start_idx - cache_len image_len = tl.load(b_image_len + image_start_num + i) - image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) + # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) for j in range(0, image_len, BLOCK_SIZE): off = j + tl.arange(0, BLOCK_SIZE) # 目前没考虑视频,所以t 恒为 0 t_pos = local_image_start_idx + off * 0 - h_pos = off // image_h + h_pos = off // image_w w_pos = off % image_w tl.store( position_ids + off + image_start_idx, From 7a904f39d054ad3271683886f8fd0bece0cce665 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 6 Jan 2026 08:14:56 +0000 Subject: [PATCH 21/71] add-neo-chat-dense --- lightllm/models/__init__.py | 1 + lightllm/models/neo_chat/infer_struct.py | 99 ---- .../layer_infer/transformer_layer_infer.py | 12 +- .../pre_and_post_layer_weight.py | 2 +- .../layer_weights/transformer_layer_weight.py | 4 +- lightllm/models/neo_chat/model.py | 108 +---- lightllm/models/neo_chat/neo_visual.py | 279 ----------- .../models/neo_chat/triton_kernel/__init__.py | 0 .../context_attention_fwd_neo.py | 452 ------------------ .../triton_kernel/get_neo_position.py | 174 ------- lightllm/models/neo_chat/vision_process.py | 141 ------ lightllm/server/tokenizer.py | 2 +- .../visualserver/model_infer/model_rpc.py | 2 +- 13 files changed, 23 insertions(+), 1253 deletions(-) delete mode 100644 lightllm/models/neo_chat/infer_struct.py delete mode 100644 lightllm/models/neo_chat/neo_visual.py delete mode 100644 lightllm/models/neo_chat/triton_kernel/__init__.py delete mode 100644 lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py delete mode 100644 lightllm/models/neo_chat/triton_kernel/get_neo_position.py delete mode 100644 lightllm/models/neo_chat/vision_process.py diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 5618dfd0c..9a51d9512 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -38,5 +38,6 @@ 
Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel +from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel from lightllm.models.neo_chat.model import NeoTpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/neo_chat/infer_struct.py b/lightllm/models/neo_chat/infer_struct.py deleted file mode 100644 index 8e5347e8b..000000000 --- a/lightllm/models/neo_chat/infer_struct.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Optional, List -import torch -import numpy as np -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager -from lightllm.models.neo_chat.triton_kernel.get_neo_position import get_neo_position_triton -from lightllm.models.llama.model import LlamaTpPartModel - - -class NeoChatInferStateInfo(LlamaInferStateInfo): - def __init__(self): - super().__init__() - self.position_cos = None - self.position_sin = None - self.position_cos_h = None - self.position_sin_h = None - self.position_cos_w = None - self.position_sin_w = None - - def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor): - LlamaInferStateInfo.init_some_extra_state(self, model, input_ids) - if self.is_prefill: - self.position_ids = self.get_neo_position(self.multimodal_params) - else: - b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] - for batch_idx, p in enumerate(self.multimodal_params): - position_delta = 0 - for image in p["images"]: - position_delta += image["grid_thwd"][3] - b_position_delta[batch_idx] = position_delta - position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) - self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone() - self.position_ids[1:].zero_() - - self.position_ids = self.position_ids.contiguous() - self.position_cos = model._cos_cached[self.position_ids[0]] - self.position_sin = model._sin_cached[self.position_ids[0]] - self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] - self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] - self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] - self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] - return - - def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: - if len(multimodal_params) == 0: - position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) - position_ids[0].copy_(self.position_ids) - return position_ids - b_image_start_idx = [] - b_image_nums = [] - b_image_start_num = [] - b_image_len = [] - image_start_num = 0 - b_image_thwd = [] - - # pad multimodal_params to batch size. 
- batch_size = self.b_q_seq_len.shape[0] - multimodal_params = multimodal_params + [ - {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params)) - ] - - for _, p in enumerate(multimodal_params): - images = p.get("images", []) - for img in images: - b_image_start_idx.append(img["start_idx"]) - a = img["start_idx"] - print(f"img start_idx: {a}") - b_image_len.append(img["token_num"]) - b_image_thwd.append(img["grid_thwd"]) - b_image_nums.append(len(images)) - b_image_start_num.append(image_start_num) - image_start_num += len(images) - - # 没有任何图片 - if image_start_num == 0: - position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) - position_ids[0].copy_(self.position_ids) - return position_ids.contiguous() - b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) - b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 - b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) - b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True) - b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True) - - position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) - position_ids[0].copy_(self.position_ids) - - get_neo_position_triton( - b_image_start_idx=b_image_start_idx, - b_image_thwd=b_image_thwd, - b_image_nums=b_image_nums, - b_image_start_num=b_image_start_num, - b_image_len=b_image_len, - position_ids=position_ids, - b_ready_cache_len=self.b_ready_cache_len, - b_q_seq_len=self.b_q_seq_len, - b_start_loc=self.b_start_loc, - ) - return position_ids diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index b0105131f..1cf13c413 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -3,19 +3,19 @@ from typing import Tuple from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo -from lightllm.models.neo_chat.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd -from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer -from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight from lightllm.distributed import all_reduce import torch.distributed as dist from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward -class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): +class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): def 
__init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return @@ -26,7 +26,7 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal return - def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight): input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) # [T, Hq*D] diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py index 7766a5d29..c1f0638ac 100644 --- a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py @@ -12,7 +12,7 @@ def rename_weight_keys(weights): weights[k.replace(prefix, "")] = weights.pop(k) -class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): +class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) return diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py index bc38f1adc..e5e769a76 100644 --- a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -1,11 +1,11 @@ -from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( NormWeight, ROWMMWeight, ) -class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): +class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, data_type, network_config, mode, quant_cfg) return diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py index 0cc469cea..14d9f96dc 100644 --- a/lightllm/models/neo_chat/model.py +++ b/lightllm/models/neo_chat/model.py @@ -1,7 +1,7 @@ import os import json from lightllm.common.build_utils import repair_config -from lightllm.models.registry import ModelRegistry +from lightllm.models.registry import ModelRegistry, llm_model_type_is from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer @@ -11,111 +11,25 @@ from lightllm.server.core.objs import SamplingParams from lightllm.models.qwen3_moe.model import Qwen3MOEModel from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem -from lightllm.models.neo_chat.vision_process import smart_resize +from lightllm.models.neo_chat_moe.vision_process import smart_resize from lightllm.models.internvl.model import InternvlTokenizer from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer -from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer +from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import 
NeoChatTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight -from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight +from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer -from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo -IMG_START_TOKEN = "" -IMG_END_TOKEN = "" -IMG_TOKEN = "" -AUDIO_START_TOKEN = "" - -class NeoChatTokenizer(BaseMultiModalTokenizer): - def __init__(self, tokenizer, model_cfg, **kwargs): - super().__init__(tokenizer) - self.tokenizer = tokenizer - self.min_pixel = model_cfg.get("vision_config").get("min_pixels") - self.max_pixel = model_cfg.get("vision_config").get("max_pixels") - self.patch_size = model_cfg.get("vision_config").get("patch_size") - self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio") - - self.image_token_id = model_cfg.get("image_token_id") - self.image_start_tag = IMG_START_TOKEN - self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag) - self.image_end_tag = IMG_END_TOKEN - self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) - - def init_imageitem_extral_params( - self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams - ): - return - - def init_audioitem_extral_params( - self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams - ): - raise NotImplementedError - - def get_audio_token_length(self, audio: AudioItem): - raise NotImplementedError - - def get_image_token_length(self, img: ImageItem): - width, height = img.image_w, img.image_h - resized_height, resized_width = smart_resize( - height=height, - width=width, - factor=int(self.patch_size // self.downsample_ratio), - min_pixels=self.min_pixel, - max_pixels=self.max_pixel, - ) - grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size - token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2)) - # 这里的grid_h和grid_w需要* self.downsample_ratio么?再仔细看下代码 - img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num) - return token_num - - # only change the impl of the encode func: - def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): - # TEXTTEXTTEXT --> TEXTTEXTTEXT - image_tokens = IMG_START_TOKEN + IMG_END_TOKEN - if multimodal_params is None: - add_special_tokens = kwargs.get("add_special_tokens", True) - return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) - image_count = len(multimodal_params.images) - prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) - - origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) - # --> id,id+1...id+num - input_ids = [] - image_id = 0 - start_idx = 0 - while True: - try: - start_idx = origin_ids.index(self.image_start_id) - if start_idx + 1 >= len(origin_ids): - break - if origin_ids[start_idx + 1] == self.image_end_id: - input_ids.extend(origin_ids[: start_idx + 1]) - token_id = multimodal_params.images[image_id].token_id - 
token_num = multimodal_params.images[image_id].token_num - multimodal_params.images[image_id].start_idx = len(input_ids) - input_ids.extend(range(token_id, token_id + token_num)) - input_ids.append(self.image_end_id) - origin_ids = origin_ids[start_idx + 2 :] - image_id += 1 - else: - raise ValueError("image token error") - except ValueError: - break - input_ids.extend(origin_ids) - return input_ids - - -@ModelRegistry(["neo_chat"], is_multimodal=True) -class NeoTpPartModel(Qwen3MOEModel): +@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3")) +class NeoTpPartModel(Qwen3TpPartModel): pre_layer_infer_class = LlamaMultimodalPreLayerInfer - transformer_layer_infer_class = NeoChatMOETransformerLayerInfer + transformer_layer_infer_class = NeoChatTransformerLayerInfer - pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight - transformer_weight_class = NeoChatMOETransformerLayerWeight + pre_and_post_weight_class = NeoChatPreAndPostLayerWeight + transformer_weight_class = NeoChatTransformerLayerWeight infer_state_class = NeoChatInferStateInfo diff --git a/lightllm/models/neo_chat/neo_visual.py b/lightllm/models/neo_chat/neo_visual.py deleted file mode 100644 index 16b30511e..000000000 --- a/lightllm/models/neo_chat/neo_visual.py +++ /dev/null @@ -1,279 +0,0 @@ -import os -import torch -import torch.nn.functional as F -from PIL import Image -from typing import List -from io import BytesIO -import torch.nn as nn -from transformers.activations import ACT2FN -from safetensors import safe_open -from lightllm.server.multimodal_params import ImageItem -from transformers.modeling_outputs import BaseModelOutputWithPooling -from transformers.modeling_utils import PreTrainedModel -from lightllm.models.neo_chat.vision_process import load_image_native -from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data - - -def apply_rotary_emb_1d( - x: torch.Tensor, - cos_cached: torch.Tensor, - sin_cached: torch.Tensor, - positions: torch.Tensor, -): - """对输入张量的一部分应用1D RoPE。""" - # x: (..., seq_len, dim_part) - # positions: (..., seq_len) - # cos_cached: (max_pos, dim_part / 2) - cos_cached = cos_cached.to(device=positions.device) - sin_cached = sin_cached.to(device=positions.device) - - cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) - sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) - - x1 = x[..., 0::2] - x2 = x[..., 1::2] - - rotated_x1 = x1 * cos - x2 * sin - rotated_x2 = x1 * sin + x2 * cos - - x_rotated = torch.empty_like(x) - x_rotated[..., 0::2] = rotated_x1 - x_rotated[..., 1::2] = rotated_x2 - return x_rotated - - -def apply_2d_rotary_pos_emb( - x: torch.Tensor, - cos_cached_x: torch.Tensor, - sin_cached_x: torch.Tensor, - cos_cached_y: torch.Tensor, - sin_cached_y: torch.Tensor, - abs_positions_x: torch.Tensor, - abs_positions_y: torch.Tensor, -): - """应用2D RoPE到输入张量x。""" - dim = x.shape[-1] - dim_half = dim // 2 - - # 假设我们将embedding的前半部分用于一个方向的RoPE,后半部分用于另一个方向 - # 例如,前一半给X坐标,后一半给Y坐标 (或者反过来,但要保持一致) - x_part_1 = x[..., :dim_half] - x_part_2 = x[..., dim_half:] - - # 将与 abs_positions_x 相关的旋转应用于 x_part_1 - rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x) - # 将与 abs_positions_y 相关的旋转应用于 x_part_2 - rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y) - - # 将它们重新拼接起来。确保顺序与你分割时一致。 - return torch.cat((rotated_part_1, rotated_part_2), dim=-1) - - -def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): - """ - Compute 
patch coordinates (x, y) - - Args: - grid_hw: (B, 2) tensor representing (H, W) per image - """ - device = grid_hw.device - B = grid_hw.shape[0] - - # Get the number of patches per image - H = grid_hw[:, 0] - W = grid_hw[:, 1] - N = H * W - N_total = N.sum() - - # Create the batch index for each patch (B x patch count) - patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) - - # Generate intra-image patch index (row-major order) - patch_id_within_image = torch.arange(N_total, device=device) - patch_id_within_image = ( - patch_id_within_image - - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] - ) - - # Get H/W for each patch according to its image - W_per_patch = W[patch_to_sample] - abs_x = patch_id_within_image % W_per_patch - abs_y = patch_id_within_image // W_per_patch - - return abs_x, abs_y - - -class NeoVisionTransformerPretrainedModel(nn.Module): - def __init__( - self, - kvargs, - hidden_size: int = 1024, - llm_hidden_size: int = 2048, - downsample_ratio: float = 0.5, - patch_size: int = 16, - num_channels: int = 3, - max_position_embeddings_vision: int = 10000, - rope_theta_vision: float = 10000.0, - min_pixels: int = 65536, - max_pixels: int = 2408448, - **kwargs, - ): - super().__init__() - self.weight_dir = kvargs["weight_dir"] - self.data_type = kvargs.get("data_type", "bfloat16") - self.embed_dim = hidden_size - self.llm_hidden_size = llm_hidden_size - self.patch_size = patch_size - self.num_channels = num_channels - self.downsample_ratio = downsample_ratio - self.downsample_factor = int(1 / downsample_ratio) - self.max_position_embeddings_vision = max_position_embeddings_vision - self.rope_theta_vision = rope_theta_vision - self.rope_dim_part = self.embed_dim // 2 - self.min_pixels = min_pixels - self.max_pixels = max_pixels - - self.patch_embedding = nn.Conv2d( - in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size - ) - - self.dense_embedding = nn.Conv2d( - in_channels=self.embed_dim, - out_channels=self.llm_hidden_size, - kernel_size=self.downsample_factor, - stride=self.downsample_factor, - ) - self.gelu = nn.GELU() - - self.repe_dim_part = self.embed_dim // 2 - self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos() - self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos() - self._init_datatype() - - def _init_datatype(self): - if isinstance(self.data_type, torch.dtype): - return - if self.data_type in ["fp16", "float16"]: - self.data_type = torch.float16 - elif self.data_type in ["bf16", "bfloat16"]: - self.data_type = torch.bfloat16 - elif self.data_type in ["fp32", "float32"]: - self.data_type = torch.float32 - else: - raise ValueError(f"Unsupport datatype {self.data_type}!") - return - - def load_model(self, weight_dir): - bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] - if bin_weight_files: - weight_dict = {} - for file_ in bin_weight_files: - f = torch.load(os.path.join(weight_dir, file_), "cpu") - for k, v in f.items(): - if "vision_model" in k: - weight_dict[k[len("vision_model.embeddings.") :]] = v - else: - hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] - weight_dict = {} - for file_ in hf_weight_files: - f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") - for k in f.keys(): - if "vision_model" in k: - weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) - self.load_state_dict(weight_dict) - - def 
precompute_rope_freqs_sincos(self): - inv_freq = 1.0 / ( - self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part) - ) - t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq) - freqs = torch.outer(t, inv_freq) - return torch.cos(freqs), torch.sin(freqs) - - def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): - """ - Apply 2D Rotary Position Embedding to the patch embeddings. - """ - abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) - embeddings = apply_2d_rotary_pos_emb( - patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 - self.cos_x, - self.sin_x, - self.cos_y, - self.sin_y, - abs_pos_x, - abs_pos_y, - ).to(self.patch_embedding.weight.dtype) - return embeddings - - def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: - pixel_values = pixel_values.view( - -1, - 3, - self.patch_size, - self.patch_size, - ) - patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) - patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) - assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[ - 0 - ], "Grid size and patch embeds size mismatch." - - patches_list = [] - cur_position = 0 - for i in range(grid_hw.shape[0]): - h, w = grid_hw[i] - patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) - patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) - patches_per_img = patches_per_img.permute(0, 2, 3, 1) - patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) - cur_position += h * w - - embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) - assert cur_position == patch_embeds.shape[0] - assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) - - return embeddings - - def encode(self, images: List[ImageItem]): - img_tensors = [] - valid_ids = [] - valid_id = 0 - img_grids = [] - uuids = [] - - for i, img in enumerate(images): - if isinstance(img, ImageItem): - uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) - image_data = Image.open(BytesIO(image_data)) - pixel_values, image_grid_hw = load_image_native( - image_data, - patch_size=self.patch_size, - downsample_ratio=self.downsample_ratio, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, - ) - img_tensors.append(pixel_values) - img_grids.append(image_grid_hw) - else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) - - # must devide merge_length - cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) - print(f"cur_num is {cur_num}") - valid_ids.append([valid_id, valid_id + cur_num]) - valid_id += cur_num - - if len(img_tensors) <= 0: - return None - - imgs = torch.cat(img_tensors, dim=0) - grid_hw = torch.cat(img_grids, dim=0) - - pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) - image_grid_hw = grid_hw.to("cuda", non_blocking=True) - - all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw) - - return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/neo_chat/triton_kernel/__init__.py b/lightllm/models/neo_chat/triton_kernel/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py deleted file mode 100644 
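# A small illustrative sketch -- not part of the patch, grid sizes are made up --
# of the (x, y) coordinates that build_abs_positions_from_grid_hw above produces.
# Patches are enumerated row-major within each image, so x = patch_id % W and
# y = patch_id // W; the first half of each patch embedding is then rotated with
# abs_x and the second half with abs_y (see apply_2d_rotary_pos_emb).
import torch

grid_hw = torch.tensor([[2, 3], [1, 2]])                   # two images: 2x3 and 1x2 patch grids
H, W = grid_hw[:, 0], grid_hw[:, 1]
pid = torch.cat([torch.arange(int(n)) for n in H * W])     # row-major patch id per image
W_per_patch = torch.repeat_interleave(W, H * W)            # W of the image each patch belongs to
abs_x, abs_y = pid % W_per_patch, pid // W_per_patch
# abs_x -> tensor([0, 1, 2, 0, 1, 2, 0, 1]); abs_y -> tensor([0, 0, 0, 1, 1, 1, 0, 0])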
index f5dae493c..000000000 --- a/lightllm/models/neo_chat/triton_kernel/context_attention_fwd_neo.py +++ /dev/null @@ -1,452 +0,0 @@ -import math -import torch -import triton -import triton.language as tl - -from lightllm.utils.device_utils import is_tesla - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - Out, - position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] - B_Start_Loc, - B_Seqlen, - Req_to_tokens, - B_req_idx, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - kv_group_num, - b_prompt_cache_len, - H: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - cur_bh = tl.program_id(1) - cur_batch = cur_bh // H - cur_head = cur_bh % H - - cur_kv_head = cur_head // kv_group_num - - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) - total_len = tl.load(B_Seqlen + cur_batch) - cur_batch_seq_len = total_len - prompt_cache_len # NEW len - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - block_start_loc = BLOCK_M * start_m - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = block_start_loc + tl.arange(0, BLOCK_M) - - # Q pointers - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - - q_valid = offs_m < cur_batch_seq_len - q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) - - # online softmax state - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - block_end_loc = total_len - - # absolute q positions in the request - q_pos = prompt_cache_len + offs_m # [M] - - # q_gid from packed position_ids (aligned with Q rows) - q_gid = tl.load( - position_ids + cur_batch_in_all_start_index + offs_m, - mask=q_valid, - other=-2147483648, - ).to(tl.int32) - - BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid - - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - k_pos = start_n + offs_n # [N] - k_valid = k_pos < block_end_loc - - # map logical pos -> mem_index (for K/V) - kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, - mask=k_valid, - other=0, - ).to(tl.int64) - - # load K - off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) - - qk = tl.dot(q, k) - - # k_gid: - # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false - # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len) - k_in_new = k_pos >= prompt_cache_len - k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new - k_gid_new = tl.load( - position_ids + cur_batch_in_all_start_index + k_new_idx, - mask=k_valid & k_in_new, - other=-2147483647, - ).to(tl.int32) - - k_gid = tl.where( - k_in_new, - k_gid_new, - (k_pos.to(tl.int32) + BIG), - ) - - # mask: causal OR same gid (only possible inside NEW part) - mask = (q_pos[:, None] >= k_pos[None, 
:]) | (q_gid[:, None] == k_gid[None, :]) - mask = mask & q_valid[:, None] & k_valid[None, :] - - qk = tl.where(mask, qk * sm_scale, -1.0e8) - - # online softmax - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij - acc = acc * alpha[:, None] - - # load V - off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) - - p = p.to(v.dtype) - acc = tl.dot(p, v, acc) - - m_i = m_ij - - acc = acc / l_i[:, None] - - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - tl.store(Out + off_o, acc, mask=q_valid[:, None]) - - -@torch.no_grad() -def context_attention_fwd_neo( - q, - k, - v, - o, - position_ids, # 1D packed like q (only NEW tokens) - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_input_len, - req_to_token_indexs, -): - # minimal safety: position_ids must cover packed q rows - assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) - - BLOCK_M = 128 if not is_tesla() else 64 - - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} - base_head_dim = Lq // 2 - sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634 - - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) - - BLOCK_N = BLOCK_M - num_warps = 4 if Lk <= 64 else 8 - num_stages = 1 - - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - o, - position_ids, - b_start_loc, - b_seq_len, - req_to_token_indexs, - b_req_idx, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), - kv_group_num=kv_group_num, - b_prompt_cache_len=b_prompt_cache_len, - H=head, - BLOCK_DMODEL=Lk, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) - - -def reference_attention( - q, - k, - v, - position_ids_q, # 1D packed like q (only NEW tokens) - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - req_to_token_indexs, -): - device = q.device - dtype = q.dtype - sum_q, Hq, D = q.shape - Hk = k.shape[1] - kv_group_num = Hq // Hk - - batch = b_seq_len.shape[0] - out = torch.empty_like(q) - scale = 1.0 / math.sqrt(D) - - for b in range(batch): - req = int(b_req_idx[b].item()) - total_len = int(b_seq_len[b].item()) - prompt_len = int(b_prompt_cache_len[b].item()) - new_len = total_len - prompt_len - - q_start = int(b_start_loc[b].item()) - q_blk = q[q_start : q_start + new_len] # [M, Hq, D] - gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] - - # gather K/V for full request by logical pos -> mem_index - token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] - k_blk = k[token_locs] # [L, Hk, D] - v_blk = v[token_locs] # [L, Hk, D] - - # expand kv heads to q heads (GQA) - k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] - v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] - - # positions - q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] - k_pos = torch.arange(0, total_len, device=device, 
dtype=torch.int64) # [L] - - # build allow mask: - # causal always - allow = k_pos[None, :] <= q_pos[:, None] - - # full-attn only inside NEW part by gid - # compare only when k_pos in NEW - k_in_new = k_pos >= prompt_len - k_rel = (k_pos - prompt_len).clamp_min(0) # [L] - # map k_rel to gid_new, but only valid where k_in_new - k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) - k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new - k_gid[k_in_new] = gid_new[k_rel[k_in_new]] - - allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) - - # scores: [Hq, M, L] - q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] - k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] - scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] - - neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) - scores = torch.where(allow[None, :, :], scores, neg) - - p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] - v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] - out_hq = torch.matmul(p, v_t) # [Hq, M, D] - out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] - - out[q_start : q_start + new_len] = out_blk - - return out - - -def make_test_case( - device="cuda", - dtype=torch.float16, - batch=3, - Hq=8, - Hk=4, - D=64, - seed=0, - base_index=50000, -): - torch.manual_seed(seed) - - # prompt (cached) len and new len - prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) - new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) - total_lens = (prompt_lens + new_lens).to(torch.int32) - - max_total_len = int(total_lens.max().item()) - max_new_len = int(new_lens.max().item()) - - # packed q start - b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) - cur = 0 - for b in range(batch): - b_start_loc[b] = cur - cur += int(new_lens[b].item()) - sum_q = cur - - b_seq_len = total_lens - b_prompt_cache_len = prompt_lens.to(torch.int32) - - # one req per batch - num_req = batch - b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) - - # global KV space large, indices not small - sum_kv = int(total_lens.sum().item()) - kv_size = base_index + sum_kv + 1024 - pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index - - # Req_to_tokens [num_req, max_total_len] - req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) - p = 0 - for r in range(num_req): - L = int(total_lens[r].item()) - req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) - p += L - - # position_ids_q: only NEW tokens, packed like q - position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) - for b in range(batch): - M = int(new_lens[b].item()) - start = int(b_start_loc[b].item()) - - gid = torch.arange(M, device=device, dtype=torch.int32) - - # make one repeated block inside NEW part to simulate image tokens - if M >= 4 and torch.rand((), device=device).item() > 0.3: - s = int(torch.randint(0, M - 2, (1,), device=device).item()) - e = min(M, s + 3) - gid[s:e] = gid[s] - - position_ids_q[start : start + M] = gid - - q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) - k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) - v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) - o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) - - return ( - q, - k, - v, - o, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_new_len, - req_to_token_indexs, - ) - - 
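# A toy, self-contained restatement (not part of the patch, all numbers made up)
# of the visibility rule that reference_attention above encodes: a key is
# attendable if it is causally earlier than the query, or if it shares the same
# packed position id (gid). Sharing a gid is what marks the image tokens of one
# image, so they attend to each other bidirectionally; cached keys get sentinel
# gids so the equality branch can never fire for them.
import torch

prompt_len = 2
q_pos = torch.tensor([2, 3, 4, 5])                        # absolute positions of the 4 new queries
k_pos = torch.arange(6)                                   # all 6 keys of the request (2 cached + 4 new)
gid_q = torch.tensor([2, 3, 3, 4])                        # new tokens at pos 3 and 4 belong to one image
gid_k = torch.cat([10 ** 9 + k_pos[:prompt_len], gid_q])  # sentinel gids for the cached keys
allow = (k_pos[None, :] <= q_pos[:, None]) | (gid_q[:, None] == gid_k[None, :])
# the row for q_pos=3 allows k_pos=4 (same gid) even though 4 > 3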
-def check_once(device="cuda", dtype=torch.float16, seed=0): - ( - q, - k, - v, - o, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_new_len, - req_to_token_indexs, - ) = make_test_case(device=device, dtype=dtype, seed=seed) - - context_attention_fwd_neo( - q, - k, - v, - o, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - max_new_len, - req_to_token_indexs, - ) - - ref = reference_attention( - q, - k, - v, - position_ids_q, - b_req_idx, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - req_to_token_indexs, - ) - - diff = (o - ref).abs() - max_abs = diff.max().item() - denom = ref.abs().max().item() + 1e-6 - max_rel = max_abs / denom - - print(f"seed={seed}, dtype={dtype}") - print(f"max_abs_error = {max_abs:.6e}") - print(f"max_rel_error = {max_rel:.6e}") - print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) - - -if __name__ == "__main__": - if not torch.cuda.is_available(): - print("No CUDA, skip.") - else: - torch.cuda.synchronize() - check_once(dtype=torch.bfloat16, seed=0) - check_once(dtype=torch.bfloat16, seed=1) - check_once(dtype=torch.bfloat16, seed=2) diff --git a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat/triton_kernel/get_neo_position.py deleted file mode 100644 index 955f48bd8..000000000 --- a/lightllm/models/neo_chat/triton_kernel/get_neo_position.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _get_neo_position_triton( - b_image_start_idx: torch.Tensor, - b_image_thwd: torch.Tensor, - b_image_thwd_stride0: torch.Tensor, - b_image_nums: torch.Tensor, - b_image_start_num: torch.Tensor, - b_image_len: torch.Tensor, - position_ids: torch.Tensor, - position_ids_stride0: torch.Tensor, - b_ready_cache_len: torch.Tensor, - b_q_seq_len: torch.Tensor, - b_start_loc: torch.Tensor, - BLOCK_SIZE: tl.constexpr, -) -> torch.Tensor: - cur_batch = tl.program_id(0) - cache_len = tl.load(b_ready_cache_len + cur_batch) - q_seq_len = tl.load(b_q_seq_len + cur_batch) - image_num = tl.load(b_image_nums + cur_batch) - image_start_num = tl.load(b_image_start_num + cur_batch) - start_loc = tl.load(b_start_loc + cur_batch) - for i in range(image_num): - local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) - image_start_idx = start_loc + local_image_start_idx - cache_len - image_len = tl.load(b_image_len + image_start_num + i) - # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) - image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) - for j in range(0, image_len, BLOCK_SIZE): - off = j + tl.arange(0, BLOCK_SIZE) - # 目前没考虑视频,所以t 恒为 0 - t_pos = local_image_start_idx + off * 0 - h_pos = off // image_w - w_pos = off % image_w - tl.store( - position_ids + off + image_start_idx, - t_pos, - mask=(off < image_len) - & (off + local_image_start_idx - cache_len < q_seq_len) - & (local_image_start_idx - cache_len + off >= 0), - ) - tl.store( - position_ids + position_ids_stride0 + off + image_start_idx, - h_pos, - mask=(off < image_len) - & (off + local_image_start_idx - cache_len < q_seq_len) - & (local_image_start_idx - cache_len + off >= 0), - ) - tl.store( - position_ids + position_ids_stride0 * 2 + off + image_start_idx, - w_pos, - mask=(off < image_len) - & (off + local_image_start_idx - cache_len < q_seq_len) - & (local_image_start_idx - cache_len + off >= 0), - ) - - for i in range(image_num): - local_image_start_idx 
= tl.load(b_image_start_idx + image_start_num + i) - image_len = tl.load(b_image_len + image_start_num + i) - image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3) - image_end = local_image_start_idx + image_len - cache_len - text_start = tl.maximum(0, image_end) - for j in range(text_start, q_seq_len, BLOCK_SIZE): - off = j + tl.arange(0, BLOCK_SIZE) - t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta - h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0) - w_pos = tl.load( - position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0 - ) - tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len)) - tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len)) - tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len)) - return - - -def get_neo_position_triton( - b_image_start_idx: torch.Tensor, - b_image_thwd: torch.Tensor, - b_image_nums: torch.Tensor, - b_image_start_num: torch.Tensor, - b_image_len: torch.Tensor, - position_ids: torch.Tensor, - b_ready_cache_len: torch.Tensor, - b_q_seq_len: torch.Tensor, - b_start_loc: torch.Tensor, -) -> torch.Tensor: - - batch_size = b_q_seq_len.shape[0] - assert batch_size == b_image_nums.shape[0] - grid = (batch_size,) - BLOCK_SIZE = 64 - _get_neo_position_triton[grid]( - b_image_start_idx=b_image_start_idx, - b_image_thwd=b_image_thwd, - b_image_thwd_stride0=b_image_thwd.stride(0), - b_image_nums=b_image_nums, - b_image_start_num=b_image_start_num, - b_image_len=b_image_len, - position_ids=position_ids, - position_ids_stride0=position_ids.stride(0), - b_ready_cache_len=b_ready_cache_len, - b_q_seq_len=b_q_seq_len, - b_start_loc=b_start_loc, - BLOCK_SIZE=BLOCK_SIZE, - ) - - -def test(): - b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") - b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") - b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") - b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") - b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") - position_ids = ( - torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") - .unsqueeze(0) - .expand(3, -1) - .contiguous() - ) - position_ids[1:].zero_() - b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") - b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") - b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") - get_neo_position_triton( - b_image_start_idx, - b_image_thwd, - b_image_nums, - b_image_start_num, - b_image_len, - position_ids, - b_ready_cache_len, - b_q_seq_len, - b_start_loc, - ) - - print(position_ids) - # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) - - # position_ids = ( - # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") - # .unsqueeze(0) - # .expand(3, -1) - # .contiguous() - # ) - # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") - # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") - # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") - - # get_neo_position_triton( - # b_image_start_idx, - # b_image_thwd, - # 
b_image_nums, - # b_image_start_num, - # b_image_len, - # position_ids, - # b_ready_cache_len, - # b_q_seq_len, - # b_start_loc, - # ) - - # print(f"old_value:\n{old_value}") - # print(f"position_ids:\n{position_ids}") - # assert torch.equal(old_value, position_ids) - - """ - tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], - [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8], - [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]], - device='cuda:0', dtype=torch.int32) - """ diff --git a/lightllm/models/neo_chat/vision_process.py b/lightllm/models/neo_chat/vision_process.py deleted file mode 100644 index aa008e18f..000000000 --- a/lightllm/models/neo_chat/vision_process.py +++ /dev/null @@ -1,141 +0,0 @@ -import re -import math -import torch -import string -import numpy as np -import pandas as pd -from PIL import Image -import torch.distributed as dist -import torchvision.transforms as T - -IMAGENET_MEAN = (0.485, 0.456, 0.406) -IMAGENET_STD = (0.229, 0.224, 0.225) - - -def round_by_factor(number: int, factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor - - -def ceil_by_factor(number: int, factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - -def floor_by_factor(number: int, factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" - return math.floor(number / factor) * factor - - -# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 -def smart_resize( - height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 -) -> tuple[int, int]: - """ - Rescales the image so that the following conditions are met: - - 1. Both dimensions (height and width) are divisible by 'factor'. - - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - - 3. The aspect ratio of the image is maintained as closely as possible. 
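# A worked example of the resize rule in smart_resize above (the 1000x800 input
# is made up; factor=32, min_pixels=65536, max_pixels=4194304 are the defaults):
#   h_bar = round(1000 / 32) * 32 = 31 * 32 = 992
#   w_bar = round(800 / 32) * 32 = 25 * 32 = 800
#   992 * 800 = 793600 pixels, already inside [min_pixels, max_pixels] -> (992, 800)
# Assuming patch_size=16 and downsample_ratio=0.5, that is a 62 x 50 patch grid,
# and the LLM-side token count is int(62 * 50 * 0.5 ** 2) = 775.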
- """ - if max(height, width) / min(height, width) > 200: - raise ValueError( - f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" - ) - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) - if h_bar * w_bar > max_pixels: - beta = math.sqrt((height * width) / max_pixels) - h_bar = max(factor, floor_by_factor(height / beta, factor)) - w_bar = max(factor, floor_by_factor(width / beta, factor)) - elif h_bar * w_bar < min_pixels: - beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) - return h_bar, w_bar - - -def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): - width, height = image.size - resized_height, resized_width = smart_resize( - height, - width, - factor=size_factor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - image = image.resize((resized_width, resized_height)) - - return image - - -def preprocess_pixel_values(pixel_values, patch_size=16): - c, h, w = pixel_values.shape - grid_h = h // patch_size - grid_w = w // patch_size - - flatten_pixel_values = ( - pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) - .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] - .reshape(grid_h * grid_w, c * patch_size ** 2) - ) - - grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) - - return flatten_pixel_values, grid_hw - - -def get_contrasting_background(image): - """ - Calculate the color (white or black) that is different from the average foreground color - to use as the background color - """ - image_np = np.array(image) - if (image_np[:, :, 3] == 0).any(): - non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] - if non_transparent_pixels.size == 0: - return None - pixel_mean = non_transparent_pixels.mean() - contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) - return contrasting_color - else: - return None - - -def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): - """ - Load and preprocess an image file, converting it to RGB mode, - resizing, normalizing, and optionally adding a thumbnail version. 
- """ - if image.mode == "RGBA": - bg_color = get_contrasting_background(image) - if bg_color: - background = Image.new("RGB", image.size, bg_color) - background.paste(image, mask=image.split()[3]) - image = background.convert("RGB") - else: - image = image.convert("RGB") - else: - image = image.convert("RGB") - - if upscale: - image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) - - transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), - T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), - ] - ) - - new_image = dynamic_preprocess_native_resolution( - image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels - ) - pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) - - print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") - - return pixel_values, grid_hw diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 17f5a741a..3563739f7 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -30,7 +30,7 @@ from ..models.qwen2_vl.model import QWen2VLTokenizer from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer -from ..models.neo_chat.model import NeoChatTokenizer +from ..models.neo_chat_moe.model import NeoChatTokenizer from ..models.gemma3.model import Gemma3Tokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d77271af8..df5d66bcb 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -19,7 +19,7 @@ from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel -from lightllm.models.neo_chat.neo_visual import NeoVisionTransformerPretrainedModel +from lightllm.models.neo_chat_moe.neo_visual import NeoVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry From 4b757ddb4150b6f0498d7b97b9a7dcd9de71d4a5 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 6 Jan 2026 08:36:24 +0000 Subject: [PATCH 22/71] add-neo-chat-dense --- lightllm/models/neo_chat_moe/__init__.py | 0 lightllm/models/neo_chat_moe/infer_struct.py | 99 ++++ .../neo_chat_moe/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 159 ++++++ .../neo_chat_moe/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 23 + .../layer_weights/transformer_layer_weight.py | 51 ++ lightllm/models/neo_chat_moe/model.py | 139 ++++++ lightllm/models/neo_chat_moe/neo_visual.py | 279 +++++++++++ .../neo_chat_moe/triton_kernel/__init__.py | 0 .../context_attention_fwd_neo.py | 452 ++++++++++++++++++ .../triton_kernel/get_neo_position.py | 174 +++++++ .../models/neo_chat_moe/vision_process.py | 141 ++++++ 13 files changed, 1517 insertions(+) create mode 100644 lightllm/models/neo_chat_moe/__init__.py create mode 100644 lightllm/models/neo_chat_moe/infer_struct.py create mode 100644 
lightllm/models/neo_chat_moe/layer_infer/__init__.py create mode 100644 lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/neo_chat_moe/layer_weights/__init__.py create mode 100644 lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/neo_chat_moe/model.py create mode 100644 lightllm/models/neo_chat_moe/neo_visual.py create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/__init__.py create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py create mode 100644 lightllm/models/neo_chat_moe/vision_process.py diff --git a/lightllm/models/neo_chat_moe/__init__.py b/lightllm/models/neo_chat_moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py new file mode 100644 index 000000000..0c7d9372e --- /dev/null +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -0,0 +1,99 @@ +from typing import Optional, List +import torch +import numpy as np +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.req_manager import ReqManager +from lightllm.models.neo_chat_moe.triton_kernel.get_neo_position import get_neo_position_triton +from lightllm.models.llama.model import LlamaTpPartModel + + +class NeoChatInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self.position_cos = None + self.position_sin = None + self.position_cos_h = None + self.position_sin_h = None + self.position_cos_w = None + self.position_sin_w = None + + def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor): + LlamaInferStateInfo.init_some_extra_state(self, model, input_ids) + if self.is_prefill: + self.position_ids = self.get_neo_position(self.multimodal_params) + else: + b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] + for batch_idx, p in enumerate(self.multimodal_params): + position_delta = 0 + for image in p["images"]: + position_delta += image["grid_thwd"][3] + b_position_delta[batch_idx] = position_delta + position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) + self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone() + self.position_ids[1:].zero_() + + self.position_ids = self.position_ids.contiguous() + self.position_cos = model._cos_cached[self.position_ids[0]] + self.position_sin = model._sin_cached[self.position_ids[0]] + self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] + self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] + self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] + self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] + return + + def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: + if len(multimodal_params) == 0: + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids + b_image_start_idx = [] + b_image_nums = [] + b_image_start_num = [] + b_image_len = [] + image_start_num = 0 + b_image_thwd = [] + + # pad multimodal_params to batch size. 
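# Requests without images still occupy a slot in the batch, so the params list is
# padded with empty image/audio entries up to batch_size before it is scanned.
# The position_ids tensor built below is laid out as (3, num_new_tokens): row 0
# carries the temporal/text position, rows 1 and 2 carry the per-patch h/w
# coordinates (they stay 0 for pure-text tokens).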
+ batch_size = self.b_q_seq_len.shape[0] + multimodal_params = multimodal_params + [ + {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params)) + ] + + for _, p in enumerate(multimodal_params): + images = p.get("images", []) + for img in images: + b_image_start_idx.append(img["start_idx"]) + a = img["start_idx"] + print(f"img start_idx: {a}") + b_image_len.append(img["token_num"]) + b_image_thwd.append(img["grid_thwd"]) + b_image_nums.append(len(images)) + b_image_start_num.append(image_start_num) + image_start_num += len(images) + + # 没有任何图片 + if image_start_num == 0: + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids.contiguous() + b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) + b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 + b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) + b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True) + b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True) + + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + + get_neo_position_triton( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + b_ready_cache_len=self.b_ready_cache_len, + b_q_seq_len=self.b_q_seq_len, + b_start_loc=self.b_start_loc, + ) + return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/__init__.py b/lightllm/models/neo_chat_moe/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..ed48a9c6f --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -0,0 +1,159 @@ +import torch +from functools import partial +from typing import Tuple +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd +from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.distributed import all_reduce +import torch.distributed as dist +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward + + +class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def _bind_attention(self): + self._context_attention_kernel = self._context_attention_kernel + self._token_attention_kernel = 
self._token_decode_attention_normal + self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal + return + + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) # [T, Hq*D] + + q_hw = layer_weight.q_hw_proj.mm(input) + q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) + + k_hw = layer_weight.k_hw_proj.mm(input) + k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) + + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] + + qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + + q_h_2d = q_h.reshape(q.shape[0], -1) + q_w_2d = q_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + + qk_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + weight=layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ) + + k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] + k_w_2d = k_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + + k = cache_kv[:, : self.tp_k_head_num_, :] + k = torch.cat([k, k_h, k_w], dim=-1) + + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + v = torch.cat([v, v_pad], dim=-1) + + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o3 = o3[:, :, : self.head_dim_].contiguous() + 
return o3.view(o3.shape[0], -1) + + def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + + q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + + k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + token_att_fwd( + q_3d, + k_3d, + att_m_tensor, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + + v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + ] + + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out + + token_softmax_reducev_fwd( + att_m_tensor, + v_3d, + o_3d, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat_moe/layer_weights/__init__.py b/lightllm/models/neo_chat_moe/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..7766a5d29 --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." 
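    # Renames keys in place, e.g. "language_model.model.embed_tokens.weight" -> "model.embed_tokens.weight"
    # (example key for illustration only), so the parent Qwen2 loader sees the unprefixed names it expects.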
+ keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..bc38f1adc --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -0,0 +1,51 @@ +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + NormWeight, + ROWMMWeight, +) + + +class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" + self._q_bias_hw_name = None + self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" + self._k_bias_hw_name = None + + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + + def _init_qkv(self): + super()._init_qkv() + self.q_hw_proj = ROWMMWeight( + weight_names=self._q_weight_hw_name, + data_type=self.data_type_, + bias_names=self._q_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="q_hw_proj", + ) + self.k_hw_proj = ROWMMWeight( + weight_names=self._k_weight_hw_name, + data_type=self.data_type_, + bias_names=self._k_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="k_hw_proj", + ) + + def _init_norm(self): + super()._init_norm() + + self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) + self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) + self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) + self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py new file mode 100644 index 000000000..e4123d109 --- /dev/null +++ b/lightllm/models/neo_chat_moe/model.py @@ -0,0 +1,139 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry, llm_model_type_is +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer +from 
lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.server.core.objs import SamplingParams +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from lightllm.models.neo_chat_moe.vision_process import smart_resize +from lightllm.models.internvl.model import InternvlTokenizer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.models.neo_chat_moe.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo + +IMG_START_TOKEN = "" +IMG_END_TOKEN = "" +IMG_TOKEN = "" +AUDIO_START_TOKEN = "" + + +class NeoChatTokenizer(BaseMultiModalTokenizer): + def __init__(self, tokenizer, model_cfg, **kwargs): + super().__init__(tokenizer) + self.tokenizer = tokenizer + self.min_pixel = model_cfg.get("vision_config").get("min_pixels") + self.max_pixel = model_cfg.get("vision_config").get("max_pixels") + self.patch_size = model_cfg.get("vision_config").get("patch_size") + self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio") + + self.image_token_id = model_cfg.get("image_token_id") + self.image_start_tag = IMG_START_TOKEN + self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag) + self.image_end_tag = IMG_END_TOKEN + self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) + + def init_imageitem_extral_params( + self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + return + + def init_audioitem_extral_params( + self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + raise NotImplementedError + + def get_audio_token_length(self, audio: AudioItem): + raise NotImplementedError + + def get_image_token_length(self, img: ImageItem): + width, height = img.image_w, img.image_h + resized_height, resized_width = smart_resize( + height=height, + width=width, + factor=int(self.patch_size // self.downsample_ratio), + min_pixels=self.min_pixel, + max_pixels=self.max_pixel, + ) + grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size + token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2)) + # 这里的grid_h和grid_w需要* self.downsample_ratio么?再仔细看下代码 + img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num) + return token_num + + # only change the impl of the encode func: + def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): + # TEXTTEXTTEXT --> TEXTTEXTTEXT + image_tokens = IMG_START_TOKEN + IMG_END_TOKEN + if multimodal_params is None: + add_special_tokens = kwargs.get("add_special_tokens", True) + return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + image_count = len(multimodal_params.images) + prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) + + origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) + # --> id,id+1...id+num + input_ids = [] + image_id = 0 + start_idx = 0 + 
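        # Walk the token stream looking for adjacent <image start>/<image end> marker pairs and
        # splice in the contiguous placeholder range [token_id, token_id + token_num) for each image,
        # recording the image's start_idx in the expanded sequence for the later position-id pass.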
while True: + try: + start_idx = origin_ids.index(self.image_start_id) + if start_idx + 1 >= len(origin_ids): + break + if origin_ids[start_idx + 1] == self.image_end_id: + input_ids.extend(origin_ids[: start_idx + 1]) + token_id = multimodal_params.images[image_id].token_id + token_num = multimodal_params.images[image_id].token_num + multimodal_params.images[image_id].start_idx = len(input_ids) + input_ids.extend(range(token_id, token_id + token_num)) + input_ids.append(self.image_end_id) + origin_ids = origin_ids[start_idx + 2 :] + image_id += 1 + else: + raise ValueError("image token error") + except ValueError: + break + input_ids.extend(origin_ids) + return input_ids + + +@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe")) +class NeoTpMOEPartModel(Qwen3MOEModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = NeoChatMOETransformerLayerInfer + + pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight + transformer_weight_class = NeoChatMOETransformerLayerWeight + + infer_state_class = NeoChatInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + pass + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["llm_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py new file mode 100644 index 000000000..852ddc095 --- /dev/null +++ b/lightllm/models/neo_chat_moe/neo_visual.py @@ -0,0 +1,279 @@ +import os +import torch +import torch.nn.functional as F +from PIL import Image +from typing import List +from io import BytesIO +import torch.nn as nn +from transformers.activations import ACT2FN +from safetensors import safe_open +from lightllm.server.multimodal_params import ImageItem +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from lightllm.models.neo_chat_moe.vision_process import load_image_native +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data + + +def apply_rotary_emb_1d( + x: torch.Tensor, + cos_cached: torch.Tensor, + sin_cached: torch.Tensor, + positions: torch.Tensor, +): + """对输入张量的一部分应用1D RoPE。""" + # x: (..., seq_len, dim_part) + # positions: (..., seq_len) + # cos_cached: (max_pos, dim_part / 2) + cos_cached = cos_cached.to(device=positions.device) + sin_cached = sin_cached.to(device=positions.device) + + cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) + sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) + + x1 = x[..., 0::2] + x2 = x[..., 1::2] + + rotated_x1 = x1 * cos - x2 * sin + rotated_x2 = x1 * sin + x2 * cos + + x_rotated = torch.empty_like(x) + x_rotated[..., 0::2] = rotated_x1 + x_rotated[..., 1::2] = rotated_x2 + return x_rotated + + +def apply_2d_rotary_pos_emb( + x: torch.Tensor, + cos_cached_x: torch.Tensor, + sin_cached_x: torch.Tensor, + cos_cached_y: torch.Tensor, + sin_cached_y: torch.Tensor, + abs_positions_x: torch.Tensor, + abs_positions_y: torch.Tensor, +): + 
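    # Factorized 2D RoPE: the channel dimension is split in half, the first half is rotated by the
    # x (column) positions and the second half by the y (row) positions, then re-concatenated.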
"""应用2D RoPE到输入张量x。""" + dim = x.shape[-1] + dim_half = dim // 2 + + # 假设我们将embedding的前半部分用于一个方向的RoPE,后半部分用于另一个方向 + # 例如,前一半给X坐标,后一半给Y坐标 (或者反过来,但要保持一致) + x_part_1 = x[..., :dim_half] + x_part_2 = x[..., dim_half:] + + # 将与 abs_positions_x 相关的旋转应用于 x_part_1 + rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x) + # 将与 abs_positions_y 相关的旋转应用于 x_part_2 + rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y) + + # 将它们重新拼接起来。确保顺序与你分割时一致。 + return torch.cat((rotated_part_1, rotated_part_2), dim=-1) + + +def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): + """ + Compute patch coordinates (x, y) + + Args: + grid_hw: (B, 2) tensor representing (H, W) per image + """ + device = grid_hw.device + B = grid_hw.shape[0] + + # Get the number of patches per image + H = grid_hw[:, 0] + W = grid_hw[:, 1] + N = H * W + N_total = N.sum() + + # Create the batch index for each patch (B x patch count) + patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) + + # Generate intra-image patch index (row-major order) + patch_id_within_image = torch.arange(N_total, device=device) + patch_id_within_image = ( + patch_id_within_image + - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] + ) + + # Get H/W for each patch according to its image + W_per_patch = W[patch_to_sample] + abs_x = patch_id_within_image % W_per_patch + abs_y = patch_id_within_image // W_per_patch + + return abs_x, abs_y + + +class NeoVisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + hidden_size: int = 1024, + llm_hidden_size: int = 2048, + downsample_ratio: float = 0.5, + patch_size: int = 16, + num_channels: int = 3, + max_position_embeddings_vision: int = 10000, + rope_theta_vision: float = 10000.0, + min_pixels: int = 65536, + max_pixels: int = 2408448, + **kwargs, + ): + super().__init__() + self.weight_dir = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + self.embed_dim = hidden_size + self.llm_hidden_size = llm_hidden_size + self.patch_size = patch_size + self.num_channels = num_channels + self.downsample_ratio = downsample_ratio + self.downsample_factor = int(1 / downsample_ratio) + self.max_position_embeddings_vision = max_position_embeddings_vision + self.rope_theta_vision = rope_theta_vision + self.rope_dim_part = self.embed_dim // 2 + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + self.patch_embedding = nn.Conv2d( + in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size + ) + + self.dense_embedding = nn.Conv2d( + in_channels=self.embed_dim, + out_channels=self.llm_hidden_size, + kernel_size=self.downsample_factor, + stride=self.downsample_factor, + ) + self.gelu = nn.GELU() + + self.repe_dim_part = self.embed_dim // 2 + self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos() + self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos() + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def load_model(self, weight_dir): + bin_weight_files = [file_ for file_ in 
os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) + self.load_state_dict(weight_dict) + + def precompute_rope_freqs_sincos(self): + inv_freq = 1.0 / ( + self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part) + ) + t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq) + freqs = torch.outer(t, inv_freq) + return torch.cos(freqs), torch.sin(freqs) + + def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): + """ + Apply 2D Rotary Position Embedding to the patch embeddings. + """ + abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) + embeddings = apply_2d_rotary_pos_emb( + patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 + self.cos_x, + self.sin_x, + self.cos_y, + self.sin_y, + abs_pos_x, + abs_pos_y, + ).to(self.patch_embedding.weight.dtype) + return embeddings + + def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + pixel_values = pixel_values.view( + -1, + 3, + self.patch_size, + self.patch_size, + ) + patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) + patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) + assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[ + 0 + ], "Grid size and patch embeds size mismatch." 
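        # Per-image downsampling: each image's patch sequence is reshaped back onto its (H, W) grid
        # and run through dense_embedding, a conv with stride downsample_factor, which cuts the token
        # count by downsample_factor**2 and projects the features to llm_hidden_size.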
+ + patches_list = [] + cur_position = 0 + for i in range(grid_hw.shape[0]): + h, w = grid_hw[i] + patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) + patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) + patches_per_img = patches_per_img.permute(0, 2, 3, 1) + patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) + cur_position += h * w + + embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) + assert cur_position == patch_embeds.shape[0] + assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) + + return embeddings + + def encode(self, images: List[ImageItem]): + img_tensors = [] + valid_ids = [] + valid_id = 0 + img_grids = [] + uuids = [] + + for i, img in enumerate(images): + if isinstance(img, ImageItem): + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + pixel_values, image_grid_hw = load_image_native( + image_data, + patch_size=self.patch_size, + downsample_ratio=self.downsample_ratio, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + img_tensors.append(pixel_values) + img_grids.append(image_grid_hw) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + + # must devide merge_length + cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) + print(f"cur_num is {cur_num}") + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + grid_hw = torch.cat(img_grids, dim=0) + + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_hw = grid_hw.to("cuda", non_blocking=True) + + all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw) + + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/neo_chat_moe/triton_kernel/__init__.py b/lightllm/models/neo_chat_moe/triton_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py new file mode 100644 index 000000000..f5dae493c --- /dev/null +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -0,0 +1,452 @@ +import math +import torch +import triton +import triton.language as tl + +from lightllm.utils.device_utils import is_tesla + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + Out, + position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] + B_Start_Loc, + B_Seqlen, + Req_to_tokens, + B_req_idx, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + kv_group_num, + b_prompt_cache_len, + H: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + cur_bh = tl.program_id(1) + cur_batch = cur_bh // H + cur_head = cur_bh % H + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + total_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = total_len - prompt_cache_len # NEW len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + 
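    # Mask rule applied below: a query attends to a key when the key is causally earlier
    # (q_pos >= k_pos) OR both tokens share the same group id from position_ids, which gives
    # bidirectional attention inside an image block; cached prompt keys get sentinel group ids
    # (BIG + k_pos) so the equality branch can never fire for them.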
block_start_loc = BLOCK_M * start_m + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = block_start_loc + tl.arange(0, BLOCK_M) + + # Q pointers + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + + q_valid = offs_m < cur_batch_seq_len + q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) + + # online softmax state + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = total_len + + # absolute q positions in the request + q_pos = prompt_cache_len + offs_m # [M] + + # q_gid from packed position_ids (aligned with Q rows) + q_gid = tl.load( + position_ids + cur_batch_in_all_start_index + offs_m, + mask=q_valid, + other=-2147483648, + ).to(tl.int32) + + BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid + + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + k_pos = start_n + offs_n # [N] + k_valid = k_pos < block_end_loc + + # map logical pos -> mem_index (for K/V) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_valid, + other=0, + ).to(tl.int64) + + # load K + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) + + qk = tl.dot(q, k) + + # k_gid: + # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false + # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len) + k_in_new = k_pos >= prompt_cache_len + k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new + k_gid_new = tl.load( + position_ids + cur_batch_in_all_start_index + k_new_idx, + mask=k_valid & k_in_new, + other=-2147483647, + ).to(tl.int32) + + k_gid = tl.where( + k_in_new, + k_gid_new, + (k_pos.to(tl.int32) + BIG), + ) + + # mask: causal OR same gid (only possible inside NEW part) + mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) + mask = mask & q_valid[:, None] & k_valid[None, :] + + qk = tl.where(mask, qk * sm_scale, -1.0e8) + + # online softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # load V + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) + + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + tl.store(Out + off_o, acc, mask=q_valid[:, None]) + + +@torch.no_grad() +def context_attention_fwd_neo( + q, + k, + v, + o, + position_ids, # 1D packed like q (only NEW tokens) + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, +): + # minimal safety: position_ids must cover packed q rows + assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) + + BLOCK_M = 128 if not is_tesla() else 64 + + Lq, Lk, Lv = 
q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128, 256} + base_head_dim = Lq // 2 + sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634 + + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + + BLOCK_N = BLOCK_M + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + position_ids, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + H=head, + BLOCK_DMODEL=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def reference_attention( + q, + k, + v, + position_ids_q, # 1D packed like q (only NEW tokens) + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, +): + device = q.device + dtype = q.dtype + sum_q, Hq, D = q.shape + Hk = k.shape[1] + kv_group_num = Hq // Hk + + batch = b_seq_len.shape[0] + out = torch.empty_like(q) + scale = 1.0 / math.sqrt(D) + + for b in range(batch): + req = int(b_req_idx[b].item()) + total_len = int(b_seq_len[b].item()) + prompt_len = int(b_prompt_cache_len[b].item()) + new_len = total_len - prompt_len + + q_start = int(b_start_loc[b].item()) + q_blk = q[q_start : q_start + new_len] # [M, Hq, D] + gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] + + # gather K/V for full request by logical pos -> mem_index + token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] + k_blk = k[token_locs] # [L, Hk, D] + v_blk = v[token_locs] # [L, Hk, D] + + # expand kv heads to q heads (GQA) + k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + + # positions + q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] + k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] + + # build allow mask: + # causal always + allow = k_pos[None, :] <= q_pos[:, None] + + # full-attn only inside NEW part by gid + # compare only when k_pos in NEW + k_in_new = k_pos >= prompt_len + k_rel = (k_pos - prompt_len).clamp_min(0) # [L] + # map k_rel to gid_new, but only valid where k_in_new + k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) + k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new + k_gid[k_in_new] = gid_new[k_rel[k_in_new]] + + allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) + + # scores: [Hq, M, L] + q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] + k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] + scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] + + neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) + scores = torch.where(allow[None, :, :], scores, neg) + + p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] + v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] + out_hq = torch.matmul(p, v_t) # [Hq, M, D] + out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] + + out[q_start : q_start + new_len] = out_blk + + return out + + +def make_test_case( + device="cuda", + 
dtype=torch.float16, + batch=3, + Hq=8, + Hk=4, + D=64, + seed=0, + base_index=50000, +): + torch.manual_seed(seed) + + # prompt (cached) len and new len + prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) + new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) + total_lens = (prompt_lens + new_lens).to(torch.int32) + + max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) + + # packed q start + b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) + cur = 0 + for b in range(batch): + b_start_loc[b] = cur + cur += int(new_lens[b].item()) + sum_q = cur + + b_seq_len = total_lens + b_prompt_cache_len = prompt_lens.to(torch.int32) + + # one req per batch + num_req = batch + b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) + + # global KV space large, indices not small + sum_kv = int(total_lens.sum().item()) + kv_size = base_index + sum_kv + 1024 + pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index + + # Req_to_tokens [num_req, max_total_len] + req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) + p = 0 + for r in range(num_req): + L = int(total_lens[r].item()) + req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) + p += L + + # position_ids_q: only NEW tokens, packed like q + position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) + for b in range(batch): + M = int(new_lens[b].item()) + start = int(b_start_loc[b].item()) + + gid = torch.arange(M, device=device, dtype=torch.int32) + + # make one repeated block inside NEW part to simulate image tokens + if M >= 4 and torch.rand((), device=device).item() > 0.3: + s = int(torch.randint(0, M - 2, (1,), device=device).item()) + e = min(M, s + 3) + gid[s:e] = gid[s] + + position_ids_q[start : start + M] = gid + + q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) + k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) + + return ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + +def check_once(device="cuda", dtype=torch.float16, seed=0): + ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) = make_test_case(device=device, dtype=dtype, seed=seed) + + context_attention_fwd_neo( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + ref = reference_attention( + q, + k, + v, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + ) + + diff = (o - ref).abs() + max_abs = diff.max().item() + denom = ref.abs().max().item() + 1e-6 + max_rel = max_abs / denom + + print(f"seed={seed}, dtype={dtype}") + print(f"max_abs_error = {max_abs:.6e}") + print(f"max_rel_error = {max_rel:.6e}") + print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + torch.cuda.synchronize() + check_once(dtype=torch.bfloat16, seed=0) + check_once(dtype=torch.bfloat16, seed=1) + check_once(dtype=torch.bfloat16, seed=2) diff --git 
a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py new file mode 100644 index 000000000..955f48bd8 --- /dev/null +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -0,0 +1,174 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_thwd_stride0: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: torch.Tensor, + position_ids: torch.Tensor, + position_ids_stride0: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, + BLOCK_SIZE: tl.constexpr, +) -> torch.Tensor: + cur_batch = tl.program_id(0) + cache_len = tl.load(b_ready_cache_len + cur_batch) + q_seq_len = tl.load(b_q_seq_len + cur_batch) + image_num = tl.load(b_image_nums + cur_batch) + image_start_num = tl.load(b_image_start_num + cur_batch) + start_loc = tl.load(b_start_loc + cur_batch) + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_start_idx = start_loc + local_image_start_idx - cache_len + image_len = tl.load(b_image_len + image_start_num + i) + # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) + image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) + for j in range(0, image_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + # 目前没考虑视频,所以t 恒为 0 + t_pos = local_image_start_idx + off * 0 + h_pos = off // image_w + w_pos = off % image_w + tl.store( + position_ids + off + image_start_idx, + t_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 + off + image_start_idx, + h_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 * 2 + off + image_start_idx, + w_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_len = tl.load(b_image_len + image_start_num + i) + image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3) + image_end = local_image_start_idx + image_len - cache_len + text_start = tl.maximum(0, image_end) + for j in range(text_start, q_seq_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta + h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0) + w_pos = tl.load( + position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0 + ) + tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len)) + return + + +def get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: 
torch.Tensor, + position_ids: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, +) -> torch.Tensor: + + batch_size = b_q_seq_len.shape[0] + assert batch_size == b_image_nums.shape[0] + grid = (batch_size,) + BLOCK_SIZE = 64 + _get_neo_position_triton[grid]( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_thwd_stride0=b_image_thwd.stride(0), + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + position_ids_stride0=position_ids.stride(0), + b_ready_cache_len=b_ready_cache_len, + b_q_seq_len=b_q_seq_len, + b_start_loc=b_start_loc, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +def test(): + b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") + b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") + b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") + position_ids = ( + torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + .unsqueeze(0) + .expand(3, -1) + .contiguous() + ) + position_ids[1:].zero_() + b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") + b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") + b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") + get_neo_position_triton( + b_image_start_idx, + b_image_thwd, + b_image_nums, + b_image_start_num, + b_image_len, + position_ids, + b_ready_cache_len, + b_q_seq_len, + b_start_loc, + ) + + print(position_ids) + # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) + + # position_ids = ( + # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + # .unsqueeze(0) + # .expand(3, -1) + # .contiguous() + # ) + # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") + # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") + # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") + + # get_neo_position_triton( + # b_image_start_idx, + # b_image_thwd, + # b_image_nums, + # b_image_start_num, + # b_image_len, + # position_ids, + # b_ready_cache_len, + # b_q_seq_len, + # b_start_loc, + # ) + + # print(f"old_value:\n{old_value}") + # print(f"position_ids:\n{position_ids}") + # assert torch.equal(old_value, position_ids) + + """ + tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], + [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8], + [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]], + device='cuda:0', dtype=torch.int32) + """ diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py new file mode 100644 index 000000000..aa008e18f --- /dev/null +++ b/lightllm/models/neo_chat_moe/vision_process.py @@ -0,0 +1,141 @@ +import re +import math +import torch +import string +import numpy as np +import pandas as pd +from PIL import Image +import torch.distributed as dist +import torchvision.transforms as T + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return 
round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 +def smart_resize( + height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, floor_by_factor(height / beta, factor)) + w_bar = max(factor, floor_by_factor(width / beta, factor)) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def preprocess_pixel_values(pixel_values, patch_size=16): + c, h, w = pixel_values.shape + grid_h = h // patch_size + grid_w = w // patch_size + + flatten_pixel_values = ( + pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) + .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] + .reshape(grid_h * grid_w, c * patch_size ** 2) + ) + + grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) + + return flatten_pixel_values, grid_hw + + +def get_contrasting_background(image): + """ + Calculate the color (white or black) that is different from the average foreground color + to use as the background color + """ + image_np = np.array(image) + if (image_np[:, :, 3] == 0).any(): + non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] + if non_transparent_pixels.size == 0: + return None + pixel_mean = non_transparent_pixels.mean() + contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) + return contrasting_color + else: + return None + + +def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): + """ + Load and preprocess an image file, converting it to RGB mode, + resizing, normalizing, and optionally adding a thumbnail version. 
+ """ + if image.mode == "RGBA": + bg_color = get_contrasting_background(image) + if bg_color: + background = Image.new("RGB", image.size, bg_color) + background.paste(image, mask=image.split()[3]) + image = background.convert("RGB") + else: + image = image.convert("RGB") + else: + image = image.convert("RGB") + + if upscale: + image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) + + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), + T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ] + ) + + new_image = dynamic_preprocess_native_resolution( + image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels + ) + pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) + + print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") + + return pixel_values, grid_hw From e208733708e86834fafd21f17adc818e4145fa3e Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Thu, 8 Jan 2026 16:42:33 +0800 Subject: [PATCH 23/71] support verl. --- lightllm/server/api_http.py | 1 + lightllm/server/api_start.py | 16 +++++++++++++++- lightllm/server/core/objs/start_args_type.py | 2 -- lightllm/server/httpserver/manager.py | 4 ++-- lightllm/utils/device_utils.py | 3 ++- lightllm/utils/serializer.py | 2 ++ lightllm/utils/torch_memory_saver_utils.py | 2 +- 7 files changed, 23 insertions(+), 7 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index ff9acafc9..a933a7947 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -335,6 +335,7 @@ async def handle_request_common(request_obj, handler): else: return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) except Exception as e: + logger.error("handle_request_common (%s) error occurred: %s", str(request_obj), str(e), exc_info=True) return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index ffd794b2d..f47d0ddd4 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import os import sys import time @@ -83,7 +84,13 @@ def signal_handler(sig, frame): return -def normal_or_p_d_start(args: StartArgs): +def _set_envs_and_config(args: StartArgs): + mp.set_start_method("spawn", force=True) + + +def _launch_subprocesses(args: StartArgs): + + _set_envs_and_config(args) set_unique_server_name(args) if not args.disable_shm_warning: @@ -350,6 +357,13 @@ def normal_or_p_d_start(args: StartArgs): ], ) + return process_manager + + +def normal_or_p_d_start(args: StartArgs): + + process_manager = _launch_subprocesses(args) + # 启动 gunicorn command = [ "gunicorn", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 0af795a09..80719745a 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -151,5 +151,3 @@ class StartArgs: weight_version: str = "default" - enable_torch_memory_saver: bool = field(default=False) - enable_weight_cpu_backup: bool = field(default=False) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 083e939ba..5136032ad 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -924,7 +924,7 @@ async def 
update_weights_from_distributed(self, request: UpdateWeightsFromDistri await self.abort_request(AbortReq(abort_all=True)) if request.flush_cache: - await self.flush_cache() + await self.flush_cache(FlushCacheReq()) return await self.http_to_model_special_request( GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request) @@ -935,7 +935,7 @@ async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) await self.abort_request(AbortReq(abort_all=True)) if request.flush_cache: - await self.flush_cache() + await self.flush_cache(FlushCacheReq()) return await self.http_to_model_special_request( GeneralHttpToModelRpcReq(func_name="update_weights_from_tensor", func_args=request) diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index d2b6d06a8..5d58d5d7b 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -85,7 +85,8 @@ def get_current_device_name(): gpu_name = torch.cuda.get_device_name(device).replace(" ", "_") return gpu_name else: - raise RuntimeError("No GPU available") + return "unknown" # need fix + # raise RuntimeError("No GPU available") @lru_cache(maxsize=None) diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py index e0b523303..ae6f418df 100644 --- a/lightllm/utils/serializer.py +++ b/lightllm/utils/serializer.py @@ -88,6 +88,8 @@ class SafeUnpickler(pickle.Unpickler): "sglang.srt.model_executor.model_runner.", "sglang.srt.layers.", "sglang.srt.utils.", + # --- LightLLM --- + "lightllm.utils.", } DENY_CLASSES = { diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py index edf15fa83..c1184ef30 100644 --- a/lightllm/utils/torch_memory_saver_utils.py +++ b/lightllm/utils/torch_memory_saver_utils.py @@ -20,7 +20,7 @@ class MemoryTag(Enum): KV_CACHE = "kv_cache" - WEIGHT = "weight" + WEIGHT = "weights" GRAPH = "graph" def is_kv_cache(self): From 245357cc25d503e63fc9f6f2690769dd2cfd48a3 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 8 Jan 2026 11:40:51 +0000 Subject: [PATCH 24/71] improve0108 --- .../common/kv_cache_mem_manager/__init__.py | 2 + .../kv_cache_mem_manager/mem_manager.py | 2 +- .../common/kv_cache_mem_manager/mem_utils.py | 7 ++- .../kv_cache_mem_manager/neo_mem_manager.py | 46 +++++++++++++++++++ lightllm/utils/kv_cache_utils.py | 20 ++++++-- 5 files changed, 71 insertions(+), 6 deletions(-) create mode 100755 lightllm/common/kv_cache_mem_manager/neo_mem_manager.py diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 66caf5d78..a78026144 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -6,6 +6,7 @@ from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager from .deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager +from .neo_mem_manager import NeoMemoryManager __all__ = [ "MemoryManager", @@ -17,4 +18,5 @@ "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", "Deepseek2FP8KVMemoryManager", + "NeoMemoryManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index b599bedfc..64483a79d 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -28,7 +28,7 @@ class MemoryManager: def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, 
mem_fraction=0.9): self.size = size self.head_num = head_num - self.head_dim = head_dim * 2 # neo kv 是[k, k_h, k_w]拼在一起的 + self.head_dim = head_dim self.layer_num = layer_num self.always_copy = always_copy self.dtype = dtype diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 259c5a56f..b655a274b 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -7,6 +7,7 @@ PPLINT4KVMemoryManager, Deepseek2MemoryManager, Deepseek2FP8KVMemoryManager, + NeoMemoryManager, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -23,7 +24,7 @@ def select_mem_manager_class(): # case 1 # 先判断是否是 deepseek 系列的模型 model_class = get_llm_model_class() - from lightllm.models import Deepseek2TpPartModel + from lightllm.models import Deepseek2TpPartModel, NeoTpMOEPartModel, NeoTpPartModel if issubclass(model_class, Deepseek2TpPartModel): mem_class = Deepseek2MemoryManager @@ -32,6 +33,10 @@ def select_mem_manager_class(): logger.info(f"Model kv cache using mode {mode}, mem_manager class: {mem_class}") return mem_class + # 判断是否是 neo 系列的模型 + elif issubclass(model_class, NeoTpMOEPartModel) or issubclass(model_class, NeoTpPartModel): + mem_class = NeoMemoryManager + return mem_class # case normal logger.info(f"mode setting params: {mode}") diff --git a/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py new file mode 100755 index 000000000..0a79aa072 --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/neo_mem_manager.py @@ -0,0 +1,46 @@ +import torch +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager + + +class NeoMemoryManager(MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + self.size = size + self.head_num = head_num + self.head_dim = head_dim * 2 # neo kv 是[k, k_h, k_w]拼在一起的 + self.layer_num = layer_num + self.always_copy = always_copy + self.dtype = dtype + # profile the max total token num if the size is None + self.profile_size(mem_fraction) + + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._mem_state_return = torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + self.mark_start = 0 + self.mark_end = self.size + + self.can_use_mem_size = self.size + + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name + + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) + + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self._init_buffers( + self.size, + dtype, + head_num, + self.head_dim, + layer_num, + ) + self.HOLD_TOKEN_MEMINDEX = self.size diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index ed183e393..26d50f810 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -21,6 +21,7 @@ PPLINT4KVMemoryManager, Deepseek2MemoryManager, Deepseek2FP8KVMemoryManager, + NeoMemoryManager, ) from typing import List, Tuple, 
Optional @@ -83,26 +84,37 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": scale_head_dim=0, scale_data_type=get_llm_data_type(), ) - elif mem_manager_class is MemoryManager: + elif mem_manager_class is PPLINT8KVMemoryManager: cpu_cache_meta = CpuKVCacheMeta( page_num=0, token_page_size=args.cpu_cache_token_page_size, layer_num=get_layer_num(args.model_dir), num_heads=get_num_key_value_heads(args.model_dir) * 2, head_dim=get_head_dim(args.model_dir), + data_type=torch.int8, + scale_head_dim=get_head_dim(args.model_dir) // 8, + scale_data_type=get_llm_data_type(), + ) + elif mem_manager_class is PPLINT8KVMemoryManager: + cpu_cache_meta = CpuKVCacheMeta( + page_num=0, + token_page_size=args.cpu_cache_token_page_size, + layer_num=get_layer_num(args.model_dir), + num_heads=get_num_key_value_heads(args.model_dir) * 2, + head_dim=get_head_dim(args.model_dir) * 2, data_type=get_llm_data_type(), scale_head_dim=0, scale_data_type=get_llm_data_type(), ) - elif mem_manager_class is PPLINT8KVMemoryManager: + elif mem_manager_class is MemoryManager: cpu_cache_meta = CpuKVCacheMeta( page_num=0, token_page_size=args.cpu_cache_token_page_size, layer_num=get_layer_num(args.model_dir), num_heads=get_num_key_value_heads(args.model_dir) * 2, head_dim=get_head_dim(args.model_dir), - data_type=torch.int8, - scale_head_dim=get_head_dim(args.model_dir) // 8, + data_type=get_llm_data_type(), + scale_head_dim=0, scale_data_type=get_llm_data_type(), ) else: From 6503ac8040e283998c01d903cba4ec72e8061943 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 8 Jan 2026 12:58:42 +0000 Subject: [PATCH 25/71] add min/max pixels sampling parameters --- lightllm/models/neo_chat_moe/model.py | 13 +++++++++++-- lightllm/models/neo_chat_moe/neo_visual.py | 8 +++++--- lightllm/server/core/objs/sampling_params.py | 7 ++++++- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py index e4123d109..ce2093d45 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -46,6 +46,15 @@ def __init__(self, tokenizer, model_cfg, **kwargs): def init_imageitem_extral_params( self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams ): + img.extra_params["min_pixels"] = ( + sampling_params.min_pixels if sampling_params.min_pixels > 0 else self.min_pixel + ) + img.extra_params["max_pixels"] = ( + sampling_params.max_pixels if sampling_params.max_pixels > 0 else self.max_pixel + ) + assert ( + img.extra_params["min_pixels"] <= img.extra_params["max_pixels"] + ), "min_pixels should be less than or equal to max_pixels" return def init_audioitem_extral_params( @@ -62,8 +71,8 @@ def get_image_token_length(self, img: ImageItem): height=height, width=width, factor=int(self.patch_size // self.downsample_ratio), - min_pixels=self.min_pixel, - max_pixels=self.max_pixel, + min_pixels=img.extra_params["min_pixels"], + max_pixels=img.extra_params["max_pixels"], ) grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2)) diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py index 852ddc095..59bd23e2b 100644 --- a/lightllm/models/neo_chat_moe/neo_visual.py +++ b/lightllm/models/neo_chat_moe/neo_visual.py @@ -247,12 +247,15 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = 
Image.open(BytesIO(image_data)) + a = img.extra_params["min_pixels"] + b = img.extra_params["max_pixels"] + print(f"self.min_pixels is {a} ,max_pixelx is {b}") pixel_values, image_grid_hw = load_image_native( image_data, patch_size=self.patch_size, downsample_ratio=self.downsample_ratio, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, + min_pixels=img.extra_params["min_pixels"], + max_pixels=img.extra_params["max_pixels"], ) img_tensors.append(pixel_values) img_grids.append(image_grid_hw) @@ -261,7 +264,6 @@ def encode(self, images: List[ImageItem]): # must devide merge_length cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) - print(f"cur_num is {cur_num}") valid_ids.append([valid_id, valid_id + cur_num]) valid_id += cur_num diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index f073319d7..3ab2c36c4 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -293,6 +293,8 @@ class SamplingParams(ctypes.Structure): ("ignore_eos", ctypes.c_bool), # the max number of image patches to be used in the internvl model, for the test ("image_max_patch_num", ctypes.c_int), + ("min_pixels", ctypes.c_int), + ("max_pixels", ctypes.c_int), ("max_new_tokens", ctypes.c_int), ("min_new_tokens", ctypes.c_int), # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty @@ -343,7 +345,8 @@ def init(self, tokenizer, **kwargs): self.top_p = kwargs.get("top_p", SamplingParams._top_p) self.top_k = kwargs.get("top_k", SamplingParams._top_k) self.ignore_eos = kwargs.get("ignore_eos", False) - self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) + self.min_pixels = kwargs.get("min_pixels", -1) + self.max_pixels = kwargs.get("max_pixels", -1) self.max_new_tokens = kwargs.get("max_new_tokens", 16) self.min_new_tokens = kwargs.get("min_new_tokens", 1) self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) @@ -482,6 +485,8 @@ def to_dict(self): "image_max_patch_num": self.image_max_patch_num, "max_new_tokens": self.max_new_tokens, "min_new_tokens": self.min_new_tokens, + "min_pixels": self.min_pixels, + "max_pixels": self.max_pixels, "exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(), "stop_sequences": self.stop_sequences.to_list(), "best_of": self.best_of, From 07df460fae6dacb64d9304583ecd842cf5435f08 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Mon, 12 Jan 2026 15:07:17 +0800 Subject: [PATCH 26/71] fix fused_moe not installed use pip. 
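
Why the empty __init__.py helps (a minimal sketch, assuming the project is
packaged with setuptools' find_packages(); the build configuration itself is
not part of this patch): find_packages() skips directories that lack an
__init__.py, so a wheel built and installed via pip ships without the
fused_moe subpackage and its imports fail at runtime. The dotted path below
is the directory added by this commit; the check is illustrative only.

    # Run from the repository root: prints False before this commit,
    # True once the empty __init__.py exists.
    from setuptools import find_packages

    pkgs = find_packages(where=".")
    target = "lightllm.common.basemodel.layer_weights.meta_weights.fused_moe"
    print(target in pkgs)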
--- .../basemodel/layer_weights/meta_weights/fused_moe/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py new file mode 100644 index 000000000..e69de29bb From a6f00fbe4a2a9f6149675557d64e51b25f24b024 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 15 Jan 2026 11:51:50 +0000 Subject: [PATCH 27/71] add visual nccl port alloc --- lightllm/server/api_start.py | 20 +++++++------------- lightllm/server/core/objs/start_args_type.py | 3 +-- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index f47d0ddd4..c64502e58 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -187,15 +187,6 @@ def _launch_subprocesses(args: StartArgs): else: args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] - # 检查visual_nccl_port数量是否足够 - # if len(args.visual_nccl_ports) < args.visual_dp: - # raise ValueError( - # f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " - # f"but got ({len(args.visual_nccl_ports)})." - # ) - # else: - # args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] - if args.visual_dp <= 0: raise ValueError("visual_dp must be a positive integer.") @@ -240,9 +231,9 @@ def _launch_subprocesses(args: StartArgs): args.data_type = get_dtype(args.model_dir) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] + already_uesd_ports = [args.nccl_port, args.port] if args.run_mode == "decode": - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port] + already_uesd_ports = [args.nccl_port, args.port, args.pd_decode_rpyc_port] # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 @@ -251,7 +242,7 @@ def _launch_subprocesses(args: StartArgs): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=9 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=9 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -268,11 +259,14 @@ def _launch_subprocesses(args: StartArgs): can_use_ports = can_use_ports[9:] visual_model_tp_ports = [] + visual_nccl_ports = [] for _ in range(args.visual_dp): tp_ports_for_dp = can_use_ports[0 : args.visual_tp] - can_use_ports = can_use_ports[args.visual_tp :] + visual_nccl_ports.append(can_use_ports[args.visual_tp]) + can_use_ports = can_use_ports[args.visual_tp + 1 :] visual_model_tp_ports.append(tp_ports_for_dp) + args.visual_nccl_ports = visual_nccl_ports # 将申请好的端口放入args参数中 args.router_port = router_port args.router_rpc_port = router_rpc_port diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 80719745a..6834518a1 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -85,7 +85,7 @@ class StartArgs: visual_gpu_ids: Optional[List[int]] = field(default=None) visual_tp: int = field(default=1) visual_dp: int = field(default=1) - visual_nccl_ports: List[int] = field(default_factory=lambda: [29500]) + 
visual_nccl_ports: List[int] = field(default=None) enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) graph_max_batch_size: int = field(default=256) @@ -150,4 +150,3 @@ class StartArgs: enable_weight_cpu_backup: bool = field(default=False) weight_version: str = "default" - From 9360197e355499bb63dda8829c4d9ad61e6896e8 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 15 Jan 2026 12:48:41 +0000 Subject: [PATCH 28/71] fix0115 --- .../layer_infer/transformer_layer_infer.py | 4 +- .../models/llama/triton_kernel/rmsnorm.py | 43 +++++++++---------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 8c6015677..2cf37a10a 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -191,14 +191,14 @@ def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: out = self.alloc_tensor(input.shape, input.dtype) - rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out) + out = rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out) return out def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: out = self.alloc_tensor(input.shape, input.dtype) - rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out) + out = rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out) return out def _get_qkv( diff --git a/lightllm/models/llama/triton_kernel/rmsnorm.py b/lightllm/models/llama/triton_kernel/rmsnorm.py index 0140847af..de6089159 100644 --- a/lightllm/models/llama/triton_kernel/rmsnorm.py +++ b/lightllm/models/llama/triton_kernel/rmsnorm.py @@ -4,7 +4,7 @@ import triton.language as tl import os -rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8")) +rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "4")) @triton.jit @@ -36,15 +36,15 @@ def _rms_norm_fwd_fused( for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) + w = tl.load(W + cols, mask=mask) x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = x * rstd - y = x_hat * w + x_hat = (x * rstd).to(tl.bfloat16) + y = x_hat * w.to(tl.bfloat16) # Write output - tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + cols * y_stride1, y, mask=mask) -def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None): +def rmsnorm_forward1(x: torch.Tensor, weight, eps, out=None): # allocate output y = torch.empty_like(x) if out is None else out # reshape input data into 2D tensor @@ -78,22 +78,19 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None): return y -def torch_rms_norm(x, weight, eps): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight +def rmsnorm_forward(hidden_states, weight, eps, out=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) + out = weight * hidden_states.to(input_dtype) + return out -def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"): - # create data - x_shape = (M, N) - w_shape = 
(x_shape[-1],) - weight = torch.rand(w_shape, dtype=dtype, device="cuda") - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - # forward pass - y_tri = rmsnorm_forward(x, weight, eps) - y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype) - - # compare - print("type:", y_tri.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_tri - y_ref))) - assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) - return +def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + print(f"norm weight dtype:{self.weight.dtype}") + return self.weight * hidden_states.to(input_dtype) From 920a741d1b7e300dd8e9b64d6db02969f2e10bfe Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 15 Jan 2026 13:11:44 +0000 Subject: [PATCH 29/71] fix0115 --- .../models/qwen3/triton_kernel/qk_norm.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/lightllm/models/qwen3/triton_kernel/qk_norm.py b/lightllm/models/qwen3/triton_kernel/qk_norm.py index 40322e509..e58cce0f2 100644 --- a/lightllm/models/qwen3/triton_kernel/qk_norm.py +++ b/lightllm/models/qwen3/triton_kernel/qk_norm.py @@ -34,7 +34,7 @@ def _rms_norm_fwd_fused( tl.store(X + cols, y.to(X.dtype.element_ty)) -def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps): +def qk_rmsnorm_forward1(x: torch.Tensor, weight: torch.Tensor, eps): """ This function is used to perform in-place RMSNorm on the input tensor, and to adapt the head_dim norm for Qwen3 MoE and the splited qk tensor layout. @@ -64,3 +64,28 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps): num_warps=4, ) return x + + +@torch.no_grad() +def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float): + assert torch.is_tensor(x) and torch.is_tensor(weight) + # assert weight.ndim == 1, weight.shape + # assert x.is_contiguous(), "x.is_contiguous()" + + head_dim = weight.numel() + x2d = x.view(-1, x.shape[-1]) # (M2, N) + M2, N = x2d.shape + assert N % head_dim == 0, (N, head_dim) + H = N // head_dim + + x3 = x2d.view(M2, H, head_dim) # (M2, H, D) + + x_fp32 = x3.to(torch.float32) + w = weight.view(1, 1, head_dim) + + var = x_fp32.pow(2).mean(dim=-1, keepdim=True) + rstd = torch.rsqrt(var + eps) + y = (x_fp32 * rstd).to(torch.bfloat16) * w + + x3.copy_(y.to(dtype=x3.dtype)) + return x From 3aa5e18ef20b4fbae8f65b4c3c359cc90bb26b82 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 16 Jan 2026 06:01:00 +0000 Subject: [PATCH 30/71] fp8 online quant for moe --- .../fused_moe/fused_moe_weight_tp.py | 32 +++++++++++++++---- .../common/quantization/quantize_method.py | 2 ++ lightllm/common/quantization/w8a8_quant.py | 17 ++++++++-- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py index bf7b218b7..023c7ba63 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py @@ -10,6 +10,7 @@ get_row_slice_mixin, get_col_slice_mixin, ) +from threading import Lock def create_tp_moe_wegiht_obj( @@ -80,6 +81,7 @@ def __init__( 
self.quantized_weight = quant_cfg.quantized_weight if self.quant_method.method_name != "none": self.weight_scale_suffix = self.quant_method.weight_scale_suffix + self.quant_method.is_moe = True self.w1_weight_name = gate_proj_name self.w2_weight_name = down_proj_name @@ -103,6 +105,9 @@ def __init__( self.col_slicer = get_col_slice_mixin( self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size() ) + self.lock = Lock() + # for online per-tensor quantization + self.gate_up_buffer = [[None, None] for _ in range(self.n_routed_experts)] self._create_weight() def _create_weight(self): @@ -206,16 +211,16 @@ def load_hf_weights(self, weights): # Load each expert with TP slicing for i_experts in range(self.n_routed_experts): self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix) - if self.w13.weight_scale is not None: + if self.w13.weight_scale is not None and self.quantized_weight: self._load_expert(i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix) - if self.w13.weight_zero_point is not None: + if self.w13.weight_zero_point is not None and self.quantized_weight: self._load_expert( i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix ) def _load_weight_func(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0): if self.quant_method.weight_need_quanted(weight): - self.quant_method.quantize(weight, weight_pack, start_idx) + self.quant_method.quantize_moe(weight, weight_pack, start_idx) else: self.quant_method.load_weight(weight, weight_pack, start_idx) @@ -225,10 +230,23 @@ def _load_expert(self, expert_idx, weights, type: str, suffix: str = "weight"): w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{suffix}" intermediate_size = self.split_inter_size load_func, slice_func = self._get_load_and_slice_func(type, is_row=True) - if w1_weight in weights: - load_func(slice_func(weights[w1_weight]), self.w13.get_expert(expert_idx), start_idx=0) - if w3_weight in weights: - load_func(slice_func(weights[w3_weight]), self.w13.get_expert(expert_idx), start_idx=intermediate_size) + if suffix == "weight": + with self.lock: + if w1_weight in weights: + self.gate_up_buffer[expert_idx][0] = slice_func(weights[w1_weight]) + if w3_weight in weights: + self.gate_up_buffer[expert_idx][1] = slice_func(weights[w3_weight]) + if None not in self.gate_up_buffer[expert_idx]: + tmp_weight = torch.cat( + [self.gate_up_buffer[expert_idx][0], self.gate_up_buffer[expert_idx][1]], dim=0 + ) + load_func(tmp_weight, self.w13.get_expert(expert_idx), start_idx=0) + self.gate_up_buffer[expert_idx] = [None, None] + else: + if w1_weight in weights: + load_func(slice_func(weights[w1_weight]), self.w13.get_expert(expert_idx), start_idx=0) + if w3_weight in weights: + load_func(slice_func(weights[w3_weight]), self.w13.get_expert(expert_idx), start_idx=intermediate_size) load_func, slice_func = self._get_load_and_slice_func(type, is_row=False) if w2_weight in weights: diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 77e59465e..971ea20a1 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -15,6 +15,8 @@ def get_expert(self, expert_idx: int): assert self.weight.ndim == 3, f"weight must be a 3D tensor, but got {self.weight.ndim}" weight = self.weight[expert_idx] weight_scale = self.weight_scale[expert_idx] if self.weight_scale is 
not None else None + if weight_scale is not None and weight_scale.ndim == 0: + weight_scale = weight_scale.unsqueeze(0) weight_zero_point = self.weight_zero_point[expert_idx] if self.weight_zero_point is not None else None return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index e4f7b552a..b2a5ce6ed 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -116,10 +116,9 @@ def __init__(self): self.is_moe = False self.has_weight_scale = True self.has_weight_zero_point = False + self.is_moe = False def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - if self.is_moe: - return self.quantize_moe(weight, output, offset) qweight, weight_scale = scaled_fp8_quant( weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) @@ -127,6 +126,14 @@ def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) return + def quantize_moe(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + qweight, weight_scale = scaled_fp8_quant( + weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False + ) + output.weight[:, :].copy_(qweight) + output.weight_scale[:].copy_(weight_scale) + return + def apply( self, input_tensor: torch.Tensor, @@ -160,7 +167,11 @@ def create_weight( ) -> WeightPack: expert_prefix = (num_experts,) if num_experts > 1 else () weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + if self.is_moe: + # per-tensor for moe + weight_scale = torch.empty((num_experts,), dtype=torch.float32).cuda(device_id) + else: + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) return WeightPack(weight=weight, weight_scale=weight_scale) From 7cb890b4c6d97632068fc18550a746e55ec53fcc Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 16 Jan 2026 12:56:56 +0000 Subject: [PATCH 31/71] hotfix for fa3 of llama --- .../layer_infer/transformer_layer_infer.py | 144 +++++++++++++----- 1 file changed, 110 insertions(+), 34 deletions(-) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index bb38c45bb..640d04d6d 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -41,6 +41,14 @@ from lightllm.utils.sgl_utils import flash_attn_with_kvcache +try: + import flash_attn + import flash_attn_3_cuda + +except ImportError: + flash_attn_3_cuda = None + logger.warning("flash_attn is not installed, you can't use the api of it. 
") + class LlamaTransformerLayerInfer(TransformerLayerInferTpl): """ """ @@ -326,25 +334,59 @@ def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionSt :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization + # k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, + # o = flash_attn_with_kvcache( + # q=q, + # k_cache=cache_k, + # v_cache=cache_v, + # page_table=infer_state.page_table, + # cache_seqlens=infer_state.b_seq_len, + # cu_seqlens_q=infer_state.cu_seqlens_q, + # cu_seqlens_k_new=infer_state.cu_seqlens_k, + # max_seqlen_q=infer_state.q_max_seq_len, + # softmax_scale=sm_scale, + # causal=True, + # window_size=(-1, -1), + # softcap=0.0, + # k_descale=k_descale, + # v_descale=v_descale, + # return_softmax_lse=False, + # ) + o, softmax_lse, *rest = flash_attn_3_cuda.fwd( + q, + cache_k, + cache_v, + None, + None, + None, # qv + None, # out + infer_state.cu_seqlens_q, + None, + infer_state.cu_seqlens_k, + None, + infer_state.b_seq_len, + infer_state.max_q_seq_len, + None, + infer_state.page_table, + None, + None, + None, + None, + None, + None, + None, + sm_scale, + True, # causal + -1, # window_size + -1, # window_size_right + 0.0, + True, + None, + 0, + None, + 0, ) return o @@ -839,25 +881,59 @@ def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionS :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization + # k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, + # o = flash_attn_with_kvcache( + # q=q, + # k_cache=cache_k, + # v_cache=cache_v, + # page_table=infer_state.page_table, + # cache_seqlens=infer_state.b_seq_len, + # cu_seqlens_q=infer_state.cu_seqlens_q, + # cu_seqlens_k_new=infer_state.cu_seqlens_k, + # max_seqlen_q=1, + # softmax_scale=sm_scale, + # causal=True, + # window_size=(-1, -1), + # softcap=0.0, + # k_descale=k_descale, + # v_descale=v_descale, + # return_softmax_lse=False, + # ) + o, softmax_lse, *rest = flash_attn_3_cuda.fwd( + q, + cache_k, + cache_v, + None, + None, + None, # qv + None, # out + infer_state.cu_seqlens_q, + None, + infer_state.cu_seqlens_k, + None, + infer_state.b_seq_len, + infer_state.max_q_seq_len, + None, + infer_state.page_table, + None, + None, + None, + None, + None, + None, + None, + sm_scale, + True, # 
causal + -1, # window_size + -1, # window_size_right + 0.0, + True, + None, + 0, + None, + 0, ) return o From c242a75a7286267a31dc7746ad8b983d419a9fc5 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 19 Jan 2026 03:23:06 +0000 Subject: [PATCH 32/71] fp8w8a8 triton config --- ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ ..._fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ ..._fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ .../{topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 50 ++++++++ ...t16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 74 ++++++++++++ ...t16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json | 74 ++++++++++++ ...torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json | 74 ++++++++++++ ...=torch.float16}_NVIDIA_H100_80GB_HBM3.json | 74 ++++++++++++ 9 files changed, 786 insertions(+) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..ee316f610 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..ddda23d25 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + 
"num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..560ca6c09 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 
16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..e950ff095 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + 
"num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..7f479b838 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 256, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "8448": { + "BLOCK_SIZE": 256, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..b3051c658 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "8448": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..fdb321221 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..a94e66935 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..441421fd5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 
+1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file From a0195aa5e89c9d1f5302872aa24bad88690d84dd Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 19 Jan 2026 09:00:44 +0000 Subject: [PATCH 33/71] fp16 config --- lightllm/common/quantization/w8a8_quant.py | 2 +- ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ ...fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json | 110 ++++++++++++++++++ 3 files changed, 221 insertions(+), 1 deletion(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index b2a5ce6ed..7e819ccd2 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -127,7 +127,7 @@ def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> return def quantize_moe(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - qweight, weight_scale = scaled_fp8_quant( + qweight, weight_scale = vllm_ops.scaled_fp8_quant( weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False ) output.weight[:, :].copy_(qweight) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..e02770109 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..0713de799 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + 
"num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file From 7f0c43756d647e26938a6b6e9ac919a70e5f3165 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 21 Jan 2026 14:05:27 +0800 Subject: [PATCH 34/71] release ipc tensor early. --- .../server/router/model_infer/mode_backend/base_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4be944584..28faa6a83 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -462,7 +462,9 @@ def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): def _unwrap_tensor(tensor, tp_rank, device): if isinstance(tensor, LocalSerializedTensor): tensor = tensor.get(tp_rank) - return tensor.to(device) + clone = tensor.to(device).clone() + del tensor # free the ipc tensor + return clone named_tensors = { name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) From 5738d9ee53078d5ca839ac5842ed4ca2f5691498 Mon Sep 17 00:00:00 2001 From: sound Date: Wed, 21 Jan 2026 17:15:18 +0800 Subject: [PATCH 35/71] bugfix: fix flattened_bucket update weights --- .../server/router/model_infer/mode_backend/base_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 28faa6a83..c198e083d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -451,7 +451,8 @@ def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): monkey_patch_torch_reductions() if request.load_format == "flattened_bucket": # Handle flattened bucket format - return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=request.named_tensors) + serialized_named_tensors = MultiprocessingSerializer.deserialize(request.serialized_named_tensors[self.rank_in_dp]) + return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=serialized_named_tensors) # We need to get device after patch otherwise the device would be wrong self.device_module = torch.get_device_module("cuda") From 
e11bf58707cec1e1bd06076f81e22d4aa655e659 Mon Sep 17 00:00:00 2001 From: sound Date: Thu, 22 Jan 2026 11:07:01 +0800 Subject: [PATCH 36/71] bugfix: fix update_weights from tensor --- .../server/router/model_infer/mode_backend/base_backend.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index c198e083d..099bd20da 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -440,9 +440,14 @@ def _update_weights_from_flattened_bucket( # Create bucket and reconstruct tensors bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata) reconstructed_tensors = bucket.reconstruct_tensors() + + named_tensors = { + name: tensor + for name, tensor in reconstructed_tensors + } # Load the reconstructed tensors using the standard method - self.model.load_weights(reconstructed_tensors) + self.model.load_weights(named_tensors) return True, "Succeeded to update parameter online from flattened bucket tensor." From ce76f8a4c2a40243633c3fffa8f1744f18c02491 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 29 Jan 2026 07:20:42 +0000 Subject: [PATCH 37/71] fix start --- lightllm/common/basemodel/basemodel.py | 2 +- lightllm/server/api_start.py | 2 -- lightllm/server/core/objs/start_args_type.py | 4 +++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index a0d2f41ea..c7cbb1f27 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -184,7 +184,7 @@ def _init_weights(self, start_layer_index=0): return def load_weights(self, weight_dict: dict): - assert isinstance(weight_dict, dict), "weight_dict must be a dict" + assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None" load_hf_weights( self.data_type, self.weight_dir_, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index c918832a9..121f3b9a6 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -244,7 +244,6 @@ def _launch_subprocesses(args: StartArgs): ( nccl_port, router_port, - router_rpc_port, detokenization_port, http_server_port, visual_port, @@ -272,7 +271,6 @@ def _launch_subprocesses(args: StartArgs): if args.pd_decode_rpyc_port is None: args.pd_decode_rpyc_port = pd_decode_rpyc_port args.router_port = router_port - args.router_rpc_port = router_rpc_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port args.visual_port = visual_port diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 365db9736..2cb12ed89 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -156,7 +156,6 @@ class StartArgs: enable_dp_prompt_cache_fetch: bool = field(default=False) # zmp ports router_port: int = field(default=None) - router_rpc_port: int = field(default=None) detokenization_port: int = field(default=None) http_server_port: int = field(default=None) visual_port: int = field(default=None) @@ -181,5 +180,8 @@ class StartArgs: disable_custom_allreduce: bool = field(default=False) enable_torch_memory_saver: bool = field(default=False) enable_weight_cpu_backup: bool = field(default=False) + 
hardware_platform: str = field(default="cuda", metadata={"choices": ["cuda", "musa"]}) + enable_torch_fallback: bool = field(default=False) + enable_triton_fallback: bool = field(default=False) weight_version: str = "default" From 45259ec01584800d2c11ae64a8936e57ee2a1b5b Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 29 Jan 2026 09:12:09 +0000 Subject: [PATCH 38/71] add-merge-kv-mode --- .../layer_infer/transformer_layer_infer.py | 71 +++++++++++++++++++ .../layer_weights/transformer_layer_weight.py | 26 ++++--- .../models/qwen3/triton_kernel/qk_norm.py | 38 +++++----- 3 files changed, 107 insertions(+), 28 deletions(-) diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index ed48a9c6f..4adf0e506 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -17,6 +17,7 @@ class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): def __init__(self, data_type, network_config, mode): + self._is_merge_kv = network_config["is_merge_kv"] super().__init__(data_type, network_config, mode) return @@ -27,6 +28,14 @@ def _bind_attention(self): return def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): + if self._is_merge_kv: + return self._get_qkv_mergekv(input, infer_state, layer_weight) + else: + return self._get_qkv_not_mergekv(input, infer_state, layer_weight) + + def _get_qkv_not_mergekv( + self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight + ): input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) # [T, Hq*D] @@ -97,6 +106,68 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC cache_kv = torch.cat([k, v], dim=1) return q, cache_kv + def _get_qkv_mergekv( + self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight + ): + input = input.view(-1, self.embed_dim_) + + q = layer_weight.q_proj.mm(input) # [T, Hq*D] + q_hw = layer_weight.q_hw_proj.mm(input) + k_hw = layer_weight.k_hw_proj.mm(input) + + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] + + qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_hw, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_hw, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + + q_hw = q_hw.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) + + qk_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + weight=layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ) + + k_hw = k_hw.view(q.shape[0], self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + + k = cache_kv[:, : self.tp_k_head_num_, :] + k = 
torch.cat([k, k_h, k_w], dim=-1) + + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + v = torch.cat([v, v_pad], dim=-1) + + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + def _context_attention_kernel( self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None ) -> torch.Tensor: diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py index bc38f1adc..d8c842bb9 100644 --- a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -7,6 +7,7 @@ class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + self._is_merge_kv = network_config["merge_kv"] super().__init__(layer_num, data_type, network_config, mode, quant_cfg) return @@ -17,11 +18,15 @@ def _init_weight_names(self): self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" self._k_bias_hw_name = None - self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" - self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + if self._is_merge_kv: + self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight" + self._k_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_hw.weight" + else: + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" - self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" - self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" def _init_qkv(self): super()._init_qkv() @@ -44,8 +49,11 @@ def _init_qkv(self): def _init_norm(self): super()._init_norm() - - self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) - self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) - self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) - self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) + if self._is_merge_kv: + self.q_norm_hw_weight_ = NormWeight(weight_name=self._q_norm_hw_name, data_type=self.data_type_) + self.k_norm_hw_weight_ = NormWeight(weight_name=self._k_norm_hw_name, data_type=self.data_type_) + else: + self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) + self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) + self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) + self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) diff --git a/lightllm/models/qwen3/triton_kernel/qk_norm.py b/lightllm/models/qwen3/triton_kernel/qk_norm.py index e58cce0f2..8e0de6a6e 100644 --- a/lightllm/models/qwen3/triton_kernel/qk_norm.py +++ b/lightllm/models/qwen3/triton_kernel/qk_norm.py @@ -34,7 +34,7 @@ def _rms_norm_fwd_fused( 
tl.store(X + cols, y.to(X.dtype.element_ty)) -def qk_rmsnorm_forward1(x: torch.Tensor, weight: torch.Tensor, eps): +def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps): """ This function is used to perform in-place RMSNorm on the input tensor, and to adapt the head_dim norm for Qwen3 MoE and the splited qk tensor layout. @@ -66,26 +66,26 @@ def qk_rmsnorm_forward1(x: torch.Tensor, weight: torch.Tensor, eps): return x -@torch.no_grad() -def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float): - assert torch.is_tensor(x) and torch.is_tensor(weight) - # assert weight.ndim == 1, weight.shape - # assert x.is_contiguous(), "x.is_contiguous()" +# @torch.no_grad() +# def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float): +# assert torch.is_tensor(x) and torch.is_tensor(weight) +# # assert weight.ndim == 1, weight.shape +# # assert x.is_contiguous(), "x.is_contiguous()" - head_dim = weight.numel() - x2d = x.view(-1, x.shape[-1]) # (M2, N) - M2, N = x2d.shape - assert N % head_dim == 0, (N, head_dim) - H = N // head_dim +# head_dim = weight.numel() +# x2d = x.view(-1, x.shape[-1]) # (M2, N) +# M2, N = x2d.shape +# assert N % head_dim == 0, (N, head_dim) +# H = N // head_dim - x3 = x2d.view(M2, H, head_dim) # (M2, H, D) +# x3 = x2d.view(M2, H, head_dim) # (M2, H, D) - x_fp32 = x3.to(torch.float32) - w = weight.view(1, 1, head_dim) +# x_fp32 = x3.to(torch.float32) +# w = weight.view(1, 1, head_dim) - var = x_fp32.pow(2).mean(dim=-1, keepdim=True) - rstd = torch.rsqrt(var + eps) - y = (x_fp32 * rstd).to(torch.bfloat16) * w +# var = x_fp32.pow(2).mean(dim=-1, keepdim=True) +# rstd = torch.rsqrt(var + eps) +# y = (x_fp32 * rstd).to(torch.bfloat16) * w - x3.copy_(y.to(dtype=x3.dtype)) - return x +# x3.copy_(y.to(dtype=x3.dtype)) +# return x From da3b53db4024786a0a98457f88ddbf0f7d716a0d Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 29 Jan 2026 09:36:15 +0000 Subject: [PATCH 39/71] add-neo-chat0129 --- .../neo_chat_moe/layer_infer/transformer_layer_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 4adf0e506..8e5c8e5cd 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -118,8 +118,8 @@ def _get_qkv_mergekv( cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(q_hw, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(k_hw, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_hw, weight=layer_weight.q_norm_hw_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_hw, weight=layer_weight.k_norm_hw_weight_.weight, eps=self.eps_) q_hw = q_hw.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) q_h, q_w = q_hw.chunk(2, dim=-1) From 043e898e6589ab40c652f3e6e930ac951c83523a Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 05:19:37 +0000 Subject: [PATCH 40/71] moe fused weight --- .../meta_weights/fused_moe/fused_moe_weight.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 6bcf7fc03..5d6519de4 
100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -295,6 +295,7 @@ def _create_weight(self): device_id=self.device_id_, num_experts=self.local_n_routed_experts, ) + self.w1, self.w3 = w13_param_list self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0]) self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) @@ -312,6 +313,8 @@ def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[st for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: self._load_expert(expert_idx, local_expert_idx, weights) + # for rl updated weight + self._load_merge_weight(weights) self._load_expert_scale( expert_idx, local_expert_idx, @@ -332,6 +335,7 @@ def _load_expert( w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" + row_slice_func = self.row_slicer._slice_weight col_slice_func = self.col_slicer._slice_weight if w1_weight in weights: @@ -341,6 +345,17 @@ def _load_expert( if w2_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) + def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): + w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" + w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}" + w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}" + if w1_merge_weight in weights: + self.quant_method.load_weight(weights[w1_merge_weight], self.w1) + if w2_merge_weight in weights: + self.quant_method.load_weight(weights[w2_merge_weight], self.w2) + if w3_merge_weight in weights: + self.quant_method.load_weight(weights[w3_merge_weight], self.w3) + def _load_expert_scale( self, expert_idx: int, From 80cfcc4f7984fbb66d5a6ac0168d33a36ef6be83 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 06:32:12 +0000 Subject: [PATCH 41/71] fix neo --- lightllm/models/llama/model.py | 2 - .../layer_infer/transformer_layer_infer.py | 36 ++++++------ .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 42 +++++++++----- lightllm/models/neo_chat_moe/infer_struct.py | 6 +- .../layer_infer/transformer_layer_infer.py | 51 +++++++++-------- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 56 +++++++++++++------ 8 files changed, 119 insertions(+), 82 deletions(-) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 7064edae8..f86bd5f83 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -110,7 +110,6 @@ def _init_to_get_rotary(self, default_base=10000): rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) base = self.config.get("rope_theta", float(default_base)) - print(f"base is {base}") if "max_sequence_length" in self.config: max_seq_len = self.config["max_sequence_length"] else: @@ -151,7 +150,6 @@ def _init_to_get_hw_rotary(self, default_base=10000): rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) base = self.config.get("rope_theta_hw", 
float(default_base)) - print(f"hw_base is {base}") if "max_sequence_length" in self.config: max_seq_len = self.config["max_sequence_length"] else: diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index 1cf13c413..a3436b28e 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -1,23 +1,20 @@ import torch from functools import partial from typing import Tuple -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight from lightllm.distributed import all_reduce import torch.distributed as dist from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def _bind_attention(self): @@ -40,25 +37,24 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + layer_weight.q_norm_weight_(q, eps=self.eps_) q_h_2d = q_h.reshape(q.shape[0], -1) q_w_2d = q_w.reshape(q.shape[0], -1) - qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) + layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] k_w_2d = k_w.reshape(q.shape[0], -1) - qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) + layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) @@ -119,7 +115,7 @@ def _context_attention_kernel( o3 = o3[:, :, : self.head_dim_].contiguous() return o3.view(o3.shape[0], -1) - def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): + def 
_token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size @@ -134,18 +130,22 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, att_m_tensor, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_kv_start_loc, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( + token_attention_softmax_and_reducev, + ) + + token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ ] - o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) token_softmax_reducev_fwd( att_m_tensor, @@ -153,7 +153,7 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, o_3d, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_kv_start_loc, infer_state.b_seq_len, ) return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py index c1f0638ac..e6489f39a 100644 --- a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py @@ -13,8 +13,8 @@ def rename_weight_keys(weights): class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py index e5e769a76..e62afae9b 100644 --- a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -1,13 +1,13 @@ from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( - NormWeight, + QKRMSNORMWeight, ROWMMWeight, ) class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): @@ -26,26 +26,42 @@ def _init_weight_names(self): def _init_qkv(self): super()._init_qkv() self.q_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.q_head_num_ * self.head_dim], weight_names=self._q_weight_hw_name, data_type=self.data_type_, bias_names=self._q_bias_hw_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - 
name="q_hw_proj", + quant_method=self.get_quant_method("q_hw_proj"), ) self.k_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.k_head_num_ * self.head_dim], weight_names=self._k_weight_hw_name, data_type=self.data_type_, bias_names=self._k_bias_hw_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="k_hw_proj", + quant_method=self.get_quant_method("k_hw_proj"), ) def _init_norm(self): super()._init_norm() - self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) - self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) - self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) - self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) + self.q_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_h_name, + data_type=self.data_type_, + ) + self.q_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_w_name, + data_type=self.data_type_, + ) + self.k_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_h_name, + data_type=self.data_type_, + ) + self.k_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_w_name, + data_type=self.data_type_, + ) diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py index 0c7d9372e..13d1ba5fc 100644 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -17,8 +17,8 @@ def __init__(self): self.position_cos_w = None self.position_sin_w = None - def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor): - LlamaInferStateInfo.init_some_extra_state(self, model, input_ids) + def init_some_extra_state(self, model: LlamaTpPartModel): + LlamaInferStateInfo.init_some_extra_state(self, model) if self.is_prefill: self.position_ids = self.get_neo_position(self.multimodal_params) else: @@ -94,6 +94,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: position_ids=position_ids, b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, - b_start_loc=self.b_start_loc, + b_start_loc=self.b_q_start_loc, ) return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 8e5c8e5cd..ad891e6ca 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -1,24 +1,21 @@ import torch from functools import partial from typing import Tuple -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight from 
lightllm.distributed import all_reduce import torch.distributed as dist from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): - def __init__(self, data_type, network_config, mode): - self._is_merge_kv = network_config["is_merge_kv"] - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + self._is_merge_kv = network_config.get("merge_kv", True) + super().__init__(data_type, network_config) return def _bind_attention(self): @@ -49,25 +46,24 @@ def _get_qkv_not_mergekv( cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + layer_weight.q_norm_weight_(q, eps=self.eps_) q_h_2d = q_h.reshape(q.shape[0], -1) q_w_2d = q_w.reshape(q.shape[0], -1) - qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + layer_weight.q_norm_h_weight_(q_h_2d, eps=self.eps_) + layer_weight.q_norm_w_weight_(q_w_2d, eps=self.eps_) q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] k_w_2d = k_w.reshape(q.shape[0], -1) - qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + layer_weight.k_norm_h_weight_(k_h_2d, eps=self.eps_) + layer_weight.k_norm_w_weight_(k_w_2d, eps=self.eps_) k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) @@ -117,16 +113,15 @@ def _get_qkv_mergekv( cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(q_hw, weight=layer_weight.q_norm_hw_weight_.weight, eps=self.eps_) - qk_rmsnorm_forward(k_hw, weight=layer_weight.k_norm_hw_weight_.weight, eps=self.eps_) + layer_weight.q_norm_weight_(q, eps=self.eps_) + layer_weight.q_norm_hw_weight_(q_hw, eps=self.eps_) + layer_weight.k_norm_hw_weight_(k_hw, eps=self.eps_) q_hw = q_hw.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) q_h, q_w = q_hw.chunk(2, dim=-1) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) @@ -180,17 +175,17 @@ def _context_attention_kernel( o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_q_start_loc, infer_state.b_seq_len, infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, + infer_state.max_q_seq_len, infer_state.req_manager.req_to_token_indexs, ) o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) o3 = o3[:, :, : self.head_dim_].contiguous() return o3.view(o3.shape[0], -1) - def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): + def 
_token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size @@ -205,18 +200,22 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, att_m_tensor, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_kv_start_loc, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( + token_attention_softmax_and_reducev, + ) + + token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ ] - o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) token_softmax_reducev_fwd( att_m_tensor, @@ -224,7 +223,7 @@ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, o_3d, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_kv_start_loc, infer_state.b_seq_len, ) return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py index 7766a5d29..4b0eae91c 100644 --- a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py @@ -13,8 +13,8 @@ def rename_weight_keys(weights): class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py index d8c842bb9..26e986cdd 100644 --- a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -1,14 +1,14 @@ from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( - NormWeight, + QKRMSNORMWeight, ROWMMWeight, ) class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - self._is_merge_kv = network_config["merge_kv"] - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + self._is_merge_kv = network_config.get("merge_kv", True) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): @@ -31,29 +31,53 @@ def _init_weight_names(self): def _init_qkv(self): super()._init_qkv() self.q_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.q_head_num_ * self.head_dim], 
weight_names=self._q_weight_hw_name, data_type=self.data_type_, bias_names=self._q_bias_hw_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_hw_proj", + quant_method=self.get_quant_method("q_hw_proj"), ) self.k_hw_proj = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.k_head_num_ * self.head_dim], weight_names=self._k_weight_hw_name, data_type=self.data_type_, bias_names=self._k_bias_hw_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="k_hw_proj", + quant_method=self.get_quant_method("k_hw_proj"), ) def _init_norm(self): super()._init_norm() if self._is_merge_kv: - self.q_norm_hw_weight_ = NormWeight(weight_name=self._q_norm_hw_name, data_type=self.data_type_) - self.k_norm_hw_weight_ = NormWeight(weight_name=self._k_norm_hw_name, data_type=self.data_type_) + self.q_norm_hw_weight_ = QKRMSNORMWeight( + dim=self.head_dim, + weight_name=self._q_norm_hw_name, + data_type=self.data_type_, + ) + self.k_norm_hw_weight_ = QKRMSNORMWeight( + dim=self.head_dim, + weight_name=self._k_norm_hw_name, + data_type=self.data_type_, + ) else: - self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) - self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) - self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) - self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) + self.q_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_h_name, + data_type=self.data_type_, + ) + self.q_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._q_norm_w_name, + data_type=self.data_type_, + ) + self.k_norm_h_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_h_name, + data_type=self.data_type_, + ) + self.k_norm_w_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + weight_name=self._k_norm_w_name, + data_type=self.data_type_, + ) From 6bbdb4feaa3523b25cc8221290a77130a933f204 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 07:58:42 +0000 Subject: [PATCH 42/71] fix launch --- lightllm/models/neo_chat_moe/neo_visual.py | 6 +++--- lightllm/server/req_id_generator.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py index 59bd23e2b..60fa82f2b 100644 --- a/lightllm/models/neo_chat_moe/neo_visual.py +++ b/lightllm/models/neo_chat_moe/neo_visual.py @@ -247,9 +247,9 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - a = img.extra_params["min_pixels"] - b = img.extra_params["max_pixels"] - print(f"self.min_pixels is {a} ,max_pixelx is {b}") + # a = img.extra_params["min_pixels"] + # b = img.extra_params["max_pixels"] + # print(f"self.min_pixels is {a} ,max_pixelx is {b}") pixel_values, image_grid_hw = load_image_native( image_data, patch_size=self.patch_size, diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py index 9bf9040c3..da1fade0d 100644 --- a/lightllm/server/req_id_generator.py +++ b/lightllm/server/req_id_generator.py @@ -30,7 +30,7 @@ def __init__(self): self.current_id.arr[0] = 0 self.current_id.arr[1] = 0 self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock") - self._wait_all_workers_ready() + # 
self._wait_all_workers_ready() logger.info("ReqIDGenerator init finished") def _wait_all_workers_ready(self): From e436ba565ed2ce2daa10d45df79ecac535c5c72b Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 08:02:32 +0000 Subject: [PATCH 43/71] fix launch --- lightllm/server/req_id_generator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py index da1fade0d..20da121dc 100644 --- a/lightllm/server/req_id_generator.py +++ b/lightllm/server/req_id_generator.py @@ -30,7 +30,8 @@ def __init__(self): self.current_id.arr[0] = 0 self.current_id.arr[1] = 0 self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock") - # self._wait_all_workers_ready() + if self.args.httpserver_workers > 1: + self._wait_all_workers_ready() logger.info("ReqIDGenerator init finished") def _wait_all_workers_ready(self): From aef65bcbef6d6ce637b4c4a1862871c479a3029d Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 08:14:05 +0000 Subject: [PATCH 44/71] fix tp slice for merged moe weight --- .../meta_weights/mm_weight/mm_slicer.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index ddbf98a86..15f050c14 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -47,17 +47,17 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten # 默认weight 的shape是 outxin,这也是目前最通用的约定。 -# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。 +# 这里约定row-wise沿着倒数第二维切分,col-wise沿着第一维切分。 class RowSliceMixin(SliceMixinTpl): def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[0]) - return weight[start:end, :] + weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-2]) + return weight[..., start:end, :] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: assert ( @@ -75,17 +75,17 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( - weight_scale.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[0]) - return weight_scale[start:end] + weight_scale.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-2]) + return weight_scale[..., start:end, :] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}" - start, end = 
self._get_slice_start_end(weight_zero_point.shape[0]) - return weight_zero_point[start:end] + weight_zero_point.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-2]) + return weight_zero_point[..., start:end, :] class ColSliceMixin(SliceMixinTpl): @@ -94,10 +94,10 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[1]) - return weight[:, start:end] + weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-1]) + return weight[..., start:end] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias / self.tp_world_size_ * self.repeat_times_ @@ -110,16 +110,16 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[1]) - return weight_scale[:, start:end] + ), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-1]) + return weight_scale[..., start:end] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[1]) - return weight_zero_point[:, start:end] + weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-1]) + return weight_zero_point[..., start:end] # awq 的量化权重是inxout存储格式,需要定制实现。 From bc87692403b1c62380a46978630c5c71c9a0d657 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 08:18:50 +0000 Subject: [PATCH 45/71] fix fusemoe weight --- .../meta_weights/fused_moe/fused_moe_weight.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 5d6519de4..77d6d40e9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -349,12 +349,14 @@ def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}" w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}" + row_slice_func = self.row_slicer._slice_weight + col_slice_func = 
self.col_slicer._slice_weight if w1_merge_weight in weights: - self.quant_method.load_weight(weights[w1_merge_weight], self.w1) + self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1) if w2_merge_weight in weights: - self.quant_method.load_weight(weights[w2_merge_weight], self.w2) + self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2) if w3_merge_weight in weights: - self.quant_method.load_weight(weights[w3_merge_weight], self.w3) + self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3) def _load_expert_scale( self, From cf5bcbf4b99045cddca61d12f3236e760cc00baa Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 10:12:17 +0000 Subject: [PATCH 46/71] fa3 for neo --- .../layer_infer/transformer_layer_infer.py | 96 +++++++++++-------- 1 file changed, 54 insertions(+), 42 deletions(-) diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index ad891e6ca..3670dac68 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -185,45 +185,57 @@ def _context_attention_kernel( o3 = o3[:, :, : self.head_dim_].contiguous() return o3.view(o3.shape[0], -1) - def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - - q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) - - att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) - - k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - token_att_fwd( - q_3d, - k_3d, - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_kv_start_loc, - infer_state.b_seq_len, - infer_state.max_kv_seq_len, - ) - - from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( - token_attention_softmax_and_reducev, - ) - - token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd - - v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ - ] - - o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) - - token_softmax_reducev_fwd( - att_m_tensor, - v_3d, - o_3d, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_kv_start_loc, - infer_state.b_seq_len, - ) - return o_3d.view(batch_size, -1) + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: NeoChatInferStateInfo, + layer_weight: NeoChatMOETransformerLayerWeight, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o_tensor = infer_state.decode_att_state.decode_att(q=_q, k=_k, v=_v, alloc_func=self.alloc_tensor) + o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)[:, :, : self.head_dim_].contiguous() + return o_tensor + + # def _token_attention_kernel(self, q, infer_state: NeoChatInferStateInfo, layer_weight): + # total_token_num = infer_state.total_token_num + # batch_size = infer_state.batch_size + + # q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + # att_m_tensor = 
self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + + # k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + # token_att_fwd( + # q_3d, + # k_3d, + # att_m_tensor, + # infer_state.req_manager.req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_kv_start_loc, + # infer_state.b_seq_len, + # infer_state.max_kv_seq_len, + # ) + + # from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.stage3_decode_att import ( + # token_attention_softmax_and_reducev, + # ) + + # token_softmax_reducev_fwd = token_attention_softmax_and_reducev.token_softmax_reducev_fwd + + # v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + # :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + # ] + + # o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) + + # token_softmax_reducev_fwd( + # att_m_tensor, + # v_3d, + # o_3d, + # infer_state.req_manager.req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_kv_start_loc, + # infer_state.b_seq_len, + # ) + # return o_3d.view(batch_size, -1) From a23288b489369b57c278e5af9b609a025e0c705d Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 10:43:43 +0000 Subject: [PATCH 47/71] fix dead visual process --- lightllm/server/visualserver/model_infer/model_rpc.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 5e07b162d..22dfa915b 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -4,6 +4,7 @@ import torch import socket import inspect +import setproctitle from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig @@ -26,6 +27,8 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name class VisualModelRpcServer(rpyc.Service): @@ -175,6 +178,8 @@ async def encode(self, images: List[ImageItem]): def _init_env(port, device_id): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server::RANK{device_id}") + start_parent_check_thread() import lightllm.utils.rpyc_fix_utils as _ From f5585404ab489521dc1a96ce7ba017d70d244668 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 10:59:37 +0000 Subject: [PATCH 48/71] auto visual dp --- lightllm/server/api_start.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 121f3b9a6..58dac941b 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -164,6 +164,10 @@ def _launch_subprocesses(args: StartArgs): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # automatically set visual_dp based on visual_tp and tp + if args.visual_tp < args.tp and args.tp % args.visual_tp == 0: + args.visual_dp = args.tp // args.visual_tp + # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) From 12c6c6b274f598c6e14484e0a026b4bdbad47125 Mon Sep 
17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 30 Jan 2026 12:15:36 +0000 Subject: [PATCH 49/71] fix format --- lightllm/utils/patch_torch.py | 8 +++----- lightllm/utils/serializer.py | 17 +++++++---------- lightllm/utils/tensor_bucket.py | 16 ++++++---------- 3 files changed, 16 insertions(+), 25 deletions(-) diff --git a/lightllm/utils/patch_torch.py b/lightllm/utils/patch_torch.py index c504e4bbc..9f51edeb6 100644 --- a/lightllm/utils/patch_torch.py +++ b/lightllm/utils/patch_torch.py @@ -9,7 +9,8 @@ def monkey_patch_torch_reductions(): """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed""" - # Currently, NPU does not support UUID. This has been temporarily commented out, with support expected in the fourth quarter. + # Currently, NPU does not support UUID. This has been temporarily commented out, + # with support expected in the fourth quarter. # if _is_npu: # return @@ -32,9 +33,7 @@ def monkey_patch_torch_reductions(): def _reduce_tensor_modified(*args, **kwargs): output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) - output_args = _modify_tuple( - output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid - ) + output_args = _modify_tuple(output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid) return output_fn, output_args @@ -62,4 +61,3 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: def _modify_tuple(t, index: int, modifier: Callable): return *t[:index], modifier(t[index]), *t[index + 1 :] - diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py index ae6f418df..d8180aeb0 100644 --- a/lightllm/utils/serializer.py +++ b/lightllm/utils/serializer.py @@ -1,4 +1,3 @@ - # copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py import base64 @@ -108,27 +107,25 @@ def find_class(self, module, name): # Block deterministic attacks if (module, name) in self.DENY_CLASSES: raise RuntimeError( - f"Blocked unsafe class loading ({module}.{name}), " - f"to prevent exploitation of CVE-2025-10164" + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" ) # Allowlist of safe-to-load modules. - if any( - (module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES - ): + if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES): return super().find_class(module, name) # Block everything else. (Potential attack surface) raise RuntimeError( - f"Blocked unsafe class loading ({module}.{name}), " - f"to prevent exploitation of CVE-2025-10164" + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" ) + @dataclass class LocalSerializedTensor: - """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data). + """torch.Tensor that gets serialized by MultiprocessingSerializer + (which only serializes a pointer and not the data). 
The i-th element in the list corresponds to i-th rank's GPU.""" values: List[bytes] def get(self, rank: int): - return MultiprocessingSerializer.deserialize(self.values[rank]) \ No newline at end of file + return MultiprocessingSerializer.deserialize(self.values[rank]) diff --git a/lightllm/utils/tensor_bucket.py b/lightllm/utils/tensor_bucket.py index 762bd0dd0..a9d7a367d 100644 --- a/lightllm/utils/tensor_bucket.py +++ b/lightllm/utils/tensor_bucket.py @@ -1,4 +1,6 @@ -# copy from https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/srt/weight_sync/tensor_bucket.py +# copy from +# https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/ +# srt/weight_sync/tensor_bucket.py from dataclasses import dataclass from typing import List, Tuple @@ -74,9 +76,7 @@ def __init__( else: # Initialize from pre-flattened data if flattened_tensor is None or metadata is None: - raise ValueError( - "Must provide either named_tensors or both flattened_tensor and metadata" - ) + raise ValueError("Must provide either named_tensors or both flattened_tensor and metadata") self.flattened_tensor = flattened_tensor self.metadata = metadata @@ -97,12 +97,8 @@ def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: reconstructed = [None] * len(self.metadata) for i, meta in enumerate(self.metadata): - tensor = ( - self.flattened_tensor[meta.start_idx : meta.end_idx] - .view(meta.dtype) - .reshape(meta.shape) - ) + tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].view(meta.dtype).reshape(meta.shape) reconstructed[i] = (meta.name, tensor) - return reconstructed \ No newline at end of file + return reconstructed From fd91cad792ce26318538ef8051a627b9f5b5474c Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 2 Feb 2026 07:46:54 +0000 Subject: [PATCH 50/71] fix decode scale --- lightllm/common/basemodel/attention/base_att.py | 1 + lightllm/common/basemodel/attention/fa3/fp.py | 7 +++++-- .../neo_chat_moe/layer_infer/transformer_layer_infer.py | 9 ++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 859d97ca8..6429bce9a 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -65,6 +65,7 @@ class AttControl: mla_prefill_dict: Dict = None mla_decode: bool = False mla_decode_dict: Dict = None + scale: float = None @dataclass diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d9..2f5fccd57 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -220,8 +220,11 @@ def _normal_decode_att( sink_weight = None k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) + if att_control.scale is not None: + sm_scale = att_control.scale + else: + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) o = flash_attn_with_kvcache( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 3670dac68..c5efe1eef 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -10,6 +10,7 @@ from lightllm.distributed import all_reduce import torch.distributed as dist 
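# Why the explicit softmax scale below matters (inferred from the packed layout;
# the commit message does not spell it out): the NEO decode query packs each
# head as [base head_dim | rope-h half | rope-w half], so q.shape[-1] is
# head_dim * 2 and FA3's default scale of 1.0 / sqrt(q.shape[-1]) would be too
# small by a factor of sqrt(2). Overriding AttControl.scale with
# 1.0 / (self.head_dim_ ** 0.5) presumably restores the scale the model was
# trained with and keeps the decode path consistent with the prefill kernel.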
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.common.basemodel.attention.base_att import AttControl class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): @@ -193,7 +194,13 @@ def _token_attention_kernel( ) -> torch.Tensor: _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) _q = q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) - o_tensor = infer_state.decode_att_state.decode_att(q=_q, k=_k, v=_v, alloc_func=self.alloc_tensor) + att_control = AttControl() + if att_control.scale is None: + att_control.scale = 1.0 / (self.head_dim_ ** 0.5) + # att_control.mla_decode_dict["softmax_scale"] = 1.0 / (self.head_dim_ ** 0.5) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor + ) o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)[:, :, : self.head_dim_].contiguous() return o_tensor From 26812639a79be7094510032eda1bb0d035f6be9b Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 2 Feb 2026 08:58:27 +0000 Subject: [PATCH 51/71] add new mode support text_ids+image_ids --- lightllm/models/neo_chat_moe/model.py | 8 +++++--- lightllm/server/httpserver/manager.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py index ce2093d45..cf4404090 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -88,9 +88,11 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): add_special_tokens = kwargs.get("add_special_tokens", True) return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) image_count = len(multimodal_params.images) - prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) - - origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) + if not kwargs.get("already_tokenized", False): + prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) + origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) + else: + origin_ids = prompt # --> id,id+1...id+num input_ids = [] image_id = 0 diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index f48b9d04c..dde1b5189 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -493,7 +493,20 @@ async def _encode( # 这里的校验对多模态不是很充分, to do if all(isinstance(e, int) for e in prompt): - if not self.enable_multimodal and not self.pd_mode.is_D(): + if self.enable_multimodal: + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" 
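                # In this branch `prompt` already arrives as a list of token ids (the text
                # was tokenized by the caller, e.g. an RL rollout worker), so it is handed
                # to the tokenizer with already_tokenized=True and only the image
                # placeholder ids still need to be expanded into contiguous per-image
                # token ranges, as in the NeoChat encode path above. A sketch of the
                # expected request shape (field names illustrative only):
                #   {"inputs": [1, 23, 908, <image_token_id>, 45], "multimodal_params": {"images": [...]}}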
+ if multimodal_params.audios: + assert self.args.enable_multimodal_audio, "audio multimodal not enabled" + await self._alloc_multimodal_resources(multimodal_params, sampling_params) + prompt_ids = self.tokenizer.encode( + prompt, + multimodal_params, + add_special_tokens=sampling_params.add_special_tokens, + already_tokenized=True, + ) + elif not self.enable_multimodal and not self.pd_mode.is_D(): if all(e < self.vocab_size for e in prompt): return prompt else: From fd17aa083555ae9551c919c58cf7ed075d84f255 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 2 Feb 2026 11:04:47 +0000 Subject: [PATCH 52/71] add new mode support text_ids+image_ids --- lightllm/server/httpserver/manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index dde1b5189..b6a7b0d12 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -506,6 +506,7 @@ async def _encode( add_special_tokens=sampling_params.add_special_tokens, already_tokenized=True, ) + return prompt_ids elif not self.enable_multimodal and not self.pd_mode.is_D(): if all(e < self.vocab_size for e in prompt): return prompt From e516bd9eabbbaefdabf33612406cc8547d7e7b36 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 2 Feb 2026 13:10:11 +0000 Subject: [PATCH 53/71] add cuda empty cache --- lightllm/common/basemodel/basemodel.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c7cbb1f27..4f37ebf24 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1070,25 +1070,32 @@ def release_weight(self): def release_kv_cache(self): self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + torch.cuda.empty_cache() def release_graph(self): self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() def release_all(self): self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() def resume_weight(self): + torch.cuda.empty_cache() self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) def resume_kv_cache(self): + torch.cuda.empty_cache() self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) def resume_graph(self): + torch.cuda.empty_cache() self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) def resume_all(self): + torch.cuda.empty_cache() self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) From 81a0c1282b89c5247139af2669d451bb4a2bd0b9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 2 Feb 2026 14:42:10 +0000 Subject: [PATCH 54/71] add invalid token ids to sampling_param for rl training --- .../triton_kernel/post_process/__init__.py | 0 .../post_process/apply_invalid_token.py | 36 +++++++++++++++++++ .../{ => post_process}/apply_penalty.py | 0 .../apply_penalty_gpu_cache.py | 0 .../server/core/objs/py_sampling_params.py | 4 +++ lightllm/server/core/objs/sampling_params.py | 28 +++++++++++++++ .../server/router/model_infer/infer_batch.py | 6 ++++ .../mode_backend/generic_post_process.py | 34 ++++++++++++++++-- 8 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/post_process/__init__.py create mode 100644 
lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py rename lightllm/common/basemodel/triton_kernel/{ => post_process}/apply_penalty.py (100%) rename lightllm/common/basemodel/triton_kernel/{ => post_process}/apply_penalty_gpu_cache.py (100%) diff --git a/lightllm/common/basemodel/triton_kernel/post_process/__init__.py b/lightllm/common/basemodel/triton_kernel/post_process/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py new file mode 100644 index 000000000..353affd8e --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py @@ -0,0 +1,36 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_apply_invalid_token( + Logits, + invalid_token_ids, + cu_invalid_token_num, + stride_logit_b, +): + cur_batch = tl.program_id(0) + start_index = tl.load(cu_invalid_token_num + cur_batch) + end_index = tl.load(cu_invalid_token_num + cur_batch + 1) + for i in range(start_index, end_index): + cur_invalid_token_id = tl.load(invalid_token_ids + i) + cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id + tl.store(cur_logit_ptr, float("-inf")) + return + + +def apply_invalid_token_ids( + Logits: torch.Tensor, + invalid_token_ids: torch.Tensor, + cu_invalid_token_num: torch.Tensor, +): + batch_size = Logits.shape[0] + grid = (batch_size,) + _fwd_kernel_apply_invalid_token[grid]( + Logits=Logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + stride_logit_b=Logits.stride(0), + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 887f360c8..9194a235d 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -54,6 +54,8 @@ def __init__( # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. 
allowed_token_ids: Optional[List[int]] = None, + # if provided, the invalid token ids will be ignored during generation + invalid_token_ids: Optional[List[int]] = None, # p d mode used params group_request_id: Optional[int] = None, # move kv to deocde node, only used in pd mode @@ -88,6 +90,7 @@ def __init__( self.guided_grammar = guided_grammar self.guided_json = guided_json self.allowed_token_ids = allowed_token_ids + self.invalid_token_ids = invalid_token_ids self.group_request_id = group_request_id self.move_kv_to_decode_node = move_kv_to_decode_node self.suggested_dp_index = suggested_dp_index @@ -267,6 +270,7 @@ def to_dict(self): ret["guided_grammar"] = self.guided_grammar ret["guided_json"] = self.guided_json ret["allowed_token_ids"] = self.allowed_token_ids + ret["invalid_token_ids"] = self.invalid_token_ids ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node return ret diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 7d4d2531b..650d15512 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -17,6 +17,7 @@ REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048)) GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048)) JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048)) +INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10)) class StopSequence(ctypes.Structure): @@ -205,6 +206,25 @@ def to_list(self): return list(self.ids[: self.size]) +class InvalidTokenIds(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH), + ("size", ctypes.c_int), + ] + + def initialize(self, ids: List[int]): + self.size = len(ids) + assert ( + self.size <= INVALID_TOKEN_IDS_MAX_LENGTH + ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}." + self.ids[: self.size] = ids[:] + return + + def to_list(self): + return list(self.ids[: self.size]) + + class ExponentialDecayLengthPenalty(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -306,6 +326,8 @@ class SamplingParams(ctypes.Structure): # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. 
("allowed_token_ids", AllowedTokenIds), + # if provided, the invalid token ids will be ignored during generation + ("invalid_token_ids", InvalidTokenIds), ("stop_sequences", StopSequenceGroups), ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty), ("group_request_id", ctypes.c_int64), # p d mode used params @@ -395,6 +417,11 @@ def init(self, tokenizer, **kwargs): self.allowed_token_ids = AllowedTokenIds() self.allowed_token_ids.initialize(allowed_token_ids) + # Initialize invalid_token_ids + invalid_token_ids = kwargs.get("invalid_token_ids", []) + self.invalid_token_ids = InvalidTokenIds() + self.invalid_token_ids.initialize(invalid_token_ids) + if self.do_sample is False: self.temperature = 1.0 self.top_p = 1.0 @@ -495,6 +522,7 @@ def to_dict(self): "guided_grammar": self.guided_grammar.to_str(), "guided_json": self.guided_json.to_str(), "allowed_token_ids": self.allowed_token_ids.to_list(), + "invalid_token_ids": self.invalid_token_ids.to_list(), "group_request_id": self.group_request_id, "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(), "skip_special_tokens": self.skip_special_tokens, diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 89230c92d..2b35fad05 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -266,6 +266,7 @@ def __init__( self.fsm_current_state: int = 0 self.allowed_token_ids = self.shm_param.allowed_token_ids.to_list() + self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list() if len(self.allowed_token_ids) == 0: self.allowed_token_ids = None @@ -281,6 +282,11 @@ def __init__( logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids") self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size] + if len(self.invalid_token_ids) > 0: + if not all(e < vocab_size for e in self.invalid_token_ids): + logger.error("invalid_token_ids contain tokenid >= vobsize, we remove these token ids") + self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size] + # nixl decode node information if self.shm_param.nixl_params.data_len > 0: self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get()) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index e2ccf290e..fc551b08e 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,7 +1,8 @@ import torch from typing import List -from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty -from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.utils.envs_utils import get_env_start_args @@ -14,7 +15,10 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_top_ks, b_length_penalty_param, b_mask_eos_reqs, + invalid_token_ids, + cu_invalid_token_num, 
is_all_greedy, + has_invalid_token_ids, ) = _get_post_sample_tensors(reqs) eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True) @@ -59,6 +63,14 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): eos_ids=eos_ids, sampling_params_manager=sampling_params_manager, ) + + if has_invalid_token_ids: + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + logits.div_(b_temperatures.view((-1, 1))) probs = torch.softmax(logits, dim=-1) @@ -112,6 +124,12 @@ def _get_post_sample_tensors(reqs: List[InferReq]): mask_eos_reqs: List[bool] = [] is_all_greedy = True + # invalid token ids + invalid_token_ids: List[int] = [] + has_invalid_token_ids = False + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for i, req_obj in enumerate(reqs): sample_param = req_obj.sampling_param shm_param = sample_param.shm_param @@ -127,6 +145,11 @@ def _get_post_sample_tensors(reqs: List[InferReq]): if top_k_val > 1: is_all_greedy = False req_idxes.append(req_obj.req_idx) + invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids) + cu_invalid_token_num.append(invalid_token_num_start) + if len(req_obj.sampling_param.invalid_token_ids) > 0: + has_invalid_token_ids = True + invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids) req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True) temperatures_cpu = torch.tensor(temperatures, dtype=torch.float, device="cpu", pin_memory=True) @@ -135,6 +158,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): length_penalty_param_cpu = torch.tensor(length_penalty_param, dtype=torch.int32, device="cpu", pin_memory=True) mask_eos_reqs_cpu = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cpu", pin_memory=True) + if has_invalid_token_ids: + invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True) + cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True) + return ( req_idxes_cpu.cuda(non_blocking=True), temperatures_cpu.cuda(non_blocking=True), @@ -142,5 +169,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]): top_ks_cpu.cuda(non_blocking=True), length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), + invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, + cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, is_all_greedy, + has_invalid_token_ids, ) From 14132d57eb8221c89c254ffbdef2dcfe3a086e2d Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 2 Feb 2026 14:42:33 +0000 Subject: [PATCH 55/71] add unitest for apply_invalid_tokens --- .../triton_kernel/test_apply_invalid_token.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py diff --git a/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py new file mode 100644 index 000000000..3b2f159f6 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py @@ -0,0 +1,50 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import ( + apply_invalid_token_ids, +) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, 
torch.float32]) +def test_apply_invalid_token_ids(dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for Triton kernels.") + + batch_size = 4 + vocab_size = 32 + logits = torch.randn((batch_size, vocab_size), device="cuda", dtype=dtype) + expected = logits.clone() + + invalid_token_ids_per_batch = [ + [1, 3, 5], + [], + [0, 2, 31], + [7], + ] + + flat_ids = [] + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for ids in invalid_token_ids_per_batch: + flat_ids.extend(ids) + invalid_token_num_start += len(ids) + cu_invalid_token_num.append(invalid_token_num_start) + + invalid_token_ids = torch.tensor(flat_ids, device="cuda", dtype=torch.int32) + cu_invalid_token_num = torch.tensor(cu_invalid_token_num, device="cuda", dtype=torch.int32) + + for batch_idx, ids in enumerate(invalid_token_ids_per_batch): + if ids: + expected[batch_idx, ids] = float("-inf") + + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + assert torch.equal(logits, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) From ed41960f38fdc5239945d3772302b9e3f9509c21 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 3 Feb 2026 07:27:08 +0000 Subject: [PATCH 56/71] add gc collect --- lightllm/common/basemodel/basemodel.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4f37ebf24..c69ae07bd 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1067,35 +1067,44 @@ def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]): def release_weight(self): self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + torch.cuda.empty_cache() + gc.collect() def release_kv_cache(self): self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) torch.cuda.empty_cache() + gc.collect() def release_graph(self): self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) torch.cuda.empty_cache() + gc.collect() def release_all(self): self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) torch.cuda.empty_cache() + gc.collect() def resume_weight(self): torch.cuda.empty_cache() + gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) def resume_kv_cache(self): torch.cuda.empty_cache() + gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) def resume_graph(self): torch.cuda.empty_cache() + gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) def resume_all(self): torch.cuda.empty_cache() + gc.collect() self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) From 706ae2e022ade69e0ed255bf0857d7e10cbc2a9e Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 3 Feb 2026 09:04:51 +0000 Subject: [PATCH 57/71] logit_bias --- lightllm/server/core/objs/sampling_params.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 650d15512..1beece342 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -418,9 +418,9 @@ def init(self, tokenizer, **kwargs): self.allowed_token_ids.initialize(allowed_token_ids) # Initialize invalid_token_ids - invalid_token_ids = 
kwargs.get("invalid_token_ids", []) + invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys()) self.invalid_token_ids = InvalidTokenIds() - self.invalid_token_ids.initialize(invalid_token_ids) + self.invalid_token_ids.initialize(list(invalid_token_ids)) if self.do_sample is False: self.temperature = 1.0 From f432f5ae30ddc1badf3dbfdfecbbec0c32ad78f9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 3 Feb 2026 10:16:20 +0000 Subject: [PATCH 58/71] logit_bias --- lightllm/server/core/objs/sampling_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 1beece342..93447830b 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -420,7 +420,7 @@ def init(self, tokenizer, **kwargs): # Initialize invalid_token_ids invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys()) self.invalid_token_ids = InvalidTokenIds() - self.invalid_token_ids.initialize(list(invalid_token_ids)) + self.invalid_token_ids.initialize(list[int](invalid_token_ids)) if self.do_sample is False: self.temperature = 1.0 From 8f8ed44ae035467fee240e774bf7c29eafdfc5d2 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 4 Feb 2026 05:50:51 +0000 Subject: [PATCH 59/71] merge main --- .../neo_chat_moe/layer_infer/transformer_layer_infer.py | 7 ++++--- lightllm/server/api_start.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index c5efe1eef..3cf5d1ecb 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -108,12 +108,13 @@ def _get_qkv_mergekv( ): input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) # [T, Hq*D] + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) q_hw = layer_weight.q_hw_proj.mm(input) k_hw = layer_weight.k_hw_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] - layer_weight.q_norm_weight_(q, eps=self.eps_) layer_weight.q_norm_hw_weight_(q_hw, eps=self.eps_) layer_weight.k_norm_hw_weight_(k_hw, eps=self.eps_) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 58dac941b..afe199d04 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -91,7 +91,6 @@ def _set_envs_and_config(args: StartArgs): def _launch_subprocesses(args: StartArgs): _set_envs_and_config(args) - set_unique_server_name(args) if not args.disable_shm_warning: check_recommended_shm_size(args) @@ -291,6 +290,8 @@ def _launch_subprocesses(args: StartArgs): args.pd_p_allowed_port_min = 20000 args.pd_p_allowed_port_max = 30000 + set_unique_server_name(args) + # p d 分离模式下,decode节点的调度间隙是0 if args.run_mode == "decode": args.router_max_wait_tokens = 0 From cac2edf0a632c589992637e9eff8b767e9013cce Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 6 Feb 2026 03:12:36 +0000 Subject: [PATCH 60/71] neo moe inferece speedup --- lightllm/models/neo_chat_moe/infer_struct.py | 4 ++ .../layer_infer/transformer_layer_infer.py | 1 + .../context_attention_fwd_neo.py | 63 +++++++------------ .../triton_kernel/get_neo_position.py | 17 +++++ 
.../models/neo_chat_moe/vision_process.py | 2 +- 5 files changed, 44 insertions(+), 43 deletions(-) diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py index 13d1ba5fc..961ed2a61 100644 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -20,6 +20,9 @@ def __init__(self): def init_some_extra_state(self, model: LlamaTpPartModel): LlamaInferStateInfo.init_some_extra_state(self, model) if self.is_prefill: + self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda( + non_blocking=True + ) self.position_ids = self.get_neo_position(self.multimodal_params) else: b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] @@ -95,5 +98,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, b_start_loc=self.b_q_start_loc, + b_image_token_tag=self.b_image_token_tag, ) return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 3cf5d1ecb..1518d6874 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -182,6 +182,7 @@ def _context_attention_kernel( infer_state.b_ready_cache_len, infer_state.max_q_seq_len, infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_tag, ) o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) o3 = o3[:, :, : self.head_dim_].contiguous() diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py index f5dae493c..42c3254e2 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -34,8 +34,10 @@ def _fwd_kernel( stride_req_to_tokens_s, kv_group_num, b_prompt_cache_len, + b_image_token_tag, H: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + QK_HEAD_DIM: tl.constexpr, + V_HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -53,16 +55,19 @@ def _fwd_kernel( cur_batch_req_idx = tl.load(B_req_idx + cur_batch) block_start_loc = BLOCK_M * start_m + if block_start_loc >= cur_batch_seq_len: + return offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d_qk = tl.arange(0, QK_HEAD_DIM) + offs_d_v = tl.arange(0, V_HEAD_DIM) offs_m = block_start_loc + tl.arange(0, BLOCK_M) # Q pointers off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh - + offs_d[None, :] * stride_qd + + offs_d_qk[None, :] * stride_qd ) q_valid = offs_m < cur_batch_seq_len @@ -71,24 +76,14 @@ def _fwd_kernel( # online softmax state m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) block_end_loc = total_len # absolute q positions in the request q_pos = prompt_cache_len + offs_m # [M] + q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False) - # q_gid from packed position_ids (aligned with Q rows) - q_gid = tl.load( - position_ids + 
cur_batch_in_all_start_index + offs_m, - mask=q_valid, - other=-2147483648, - ).to(tl.int32) - - BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid - - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + for start_n in range(0, block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) k_pos = start_n + offs_n # [N] @@ -102,32 +97,13 @@ def _fwd_kernel( ).to(tl.int64) # load K - off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d_qk[:, None] * stride_kd k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) - - qk = tl.dot(q, k) - - # k_gid: - # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false - # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len) - k_in_new = k_pos >= prompt_cache_len - k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new - k_gid_new = tl.load( - position_ids + cur_batch_in_all_start_index + k_new_idx, - mask=k_valid & k_in_new, - other=-2147483647, - ).to(tl.int32) - - k_gid = tl.where( - k_in_new, - k_gid_new, - (k_pos.to(tl.int32) + BIG), - ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # mask: causal OR same gid (only possible inside NEW part) - mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) - mask = mask & q_valid[:, None] & k_valid[None, :] - + mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None] qk = tl.where(mask, qk * sm_scale, -1.0e8) # online softmax @@ -141,7 +117,7 @@ def _fwd_kernel( acc = acc * alpha[:, None] # load V - off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d_v[None, :] * stride_vd v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) p = p.to(v.dtype) @@ -154,7 +130,7 @@ def _fwd_kernel( off_o = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh - + offs_d[None, :] * stride_od + + offs_d_v[None, :] * stride_od ) tl.store(Out + off_o, acc, mask=q_valid[:, None]) @@ -172,6 +148,7 @@ def context_attention_fwd_neo( b_prompt_cache_len, max_input_len, req_to_token_indexs, + b_image_token_tag, ): # minimal safety: position_ids must cover packed q rows assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) @@ -220,8 +197,10 @@ def context_attention_fwd_neo( req_to_token_indexs.stride(1), kv_group_num=kv_group_num, b_prompt_cache_len=b_prompt_cache_len, + b_image_token_tag=b_image_token_tag, H=head, - BLOCK_DMODEL=Lk, + QK_HEAD_DIM=Lk, + V_HEAD_DIM=Lk // 2, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py index 955f48bd8..1a3d4af73 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -16,6 +16,7 @@ def _get_neo_position_triton( b_ready_cache_len: torch.Tensor, b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, + b_image_token_tag: torch.Tensor, BLOCK_SIZE: tl.constexpr, ) -> torch.Tensor: cur_batch = tl.program_id(0) @@ -36,6 +37,13 @@ def _get_neo_position_triton( t_pos = local_image_start_idx + off * 0 h_pos = off // image_w w_pos = off % image_w + tl.store( + 
b_image_token_tag + off + image_start_idx, + True, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) tl.store( position_ids + off + image_start_idx, t_pos, @@ -87,6 +95,7 @@ def get_neo_position_triton( b_ready_cache_len: torch.Tensor, b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, + b_image_token_tag: torch.Tensor, ) -> torch.Tensor: batch_size = b_q_seq_len.shape[0] @@ -105,6 +114,7 @@ def get_neo_position_triton( b_ready_cache_len=b_ready_cache_len, b_q_seq_len=b_q_seq_len, b_start_loc=b_start_loc, + b_image_token_tag=b_image_token_tag, BLOCK_SIZE=BLOCK_SIZE, ) @@ -121,6 +131,7 @@ def test(): .expand(3, -1) .contiguous() ) + b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda") position_ids[1:].zero_() b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") @@ -135,8 +146,10 @@ def test(): b_ready_cache_len, b_q_seq_len, b_start_loc, + b_image_token_tag, ) + print(b_image_token_tag) print(position_ids) # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) @@ -172,3 +185,7 @@ def test(): [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]], device='cuda:0', dtype=torch.int32) """ + + +if __name__ == "__main__": + test() diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py index aa008e18f..fbd57a5e9 100644 --- a/lightllm/models/neo_chat_moe/vision_process.py +++ b/lightllm/models/neo_chat_moe/vision_process.py @@ -136,6 +136,6 @@ def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=655 ) pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) - print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") + # print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") return pixel_values, grid_hw From 02078ad1ded3babcd1b3c8d152e83ee1238ba7d9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 9 Feb 2026 10:44:23 +0000 Subject: [PATCH 61/71] port random generate --- lightllm/utils/net_utils.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index 486414e88..51ec443d1 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -2,7 +2,6 @@ import subprocess import ipaddress import random -import portpicker from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -13,25 +12,37 @@ def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000) used_nccl_ports = [] port_list = [] + locked_sockets = [] + used_set = set(used_nccl_ports) + max_port = 65535 max_attempts = num * 50 # Allow more attempts to find ports in range for _ in range(max_attempts): if len(port_list) >= num: break - try: - port = portpicker.pick_unused_port() - - if port >= from_port_num and port not in used_nccl_ports: - port_list.append(port) - logger.debug(f"Allocated port: {port}") - else: - logger.debug(f"Port {port} is out of range or in used_nccl_ports, skipping") + # 在 [from_port_num, 65535] 范围内随机选端口,避免多进程同时启动时分配到相同端口 + port = random.randint(from_port_num, max_port) + if port in used_set: + continue - except Exception as e: - logger.warning(f"Failed to 
allocate port: {e}") + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + sock.bind(("", port)) + port_list.append(port) + used_set.add(port) + locked_sockets.append(sock) + logger.debug(f"Allocated and locked port: {port}") + + except OSError as e: + sock.close() + logger.warning(f"Failed to bind port: {e}") continue + for sock in locked_sockets: + sock.close() + if len(port_list) < num: logger.error(f"Failed to allocate {num} ports, only got {len(port_list)}") return None From 68954b02f49951c0f517a231a95bf5422caf55f2 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 9 Feb 2026 11:11:48 +0000 Subject: [PATCH 62/71] feat: add MoE expert routing capture for R3 rollout replay --- .gitignore | 1 + .../fused_moe/fused_moe_weight.py | 5 + .../fused_moe/gpt_oss_fused_moe_weight_tp.py | 8 + .../meta_weights/fused_moe/impl/base_impl.py | 2 + .../fused_moe/impl/triton_impl.py | 7 + lightllm/common/basemodel/routing_manager.py | 224 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 2 + .../layer_weights/transformer_layer_weight.py | 4 + lightllm/models/deepseek2/model.py | 4 + .../layer_infer/transformer_layer_infer.py | 1 + .../layer_weights/transformer_layer_weight.py | 1 + lightllm/models/gpt_oss/model.py | 7 + lightllm/models/llama/model.py | 16 +- .../models/mixtral/layer_infer/_custom_ops.py | 46 ---- .../layer_infer/transformer_layer_infer.py | 33 +-- .../layer_weights/transformer_layer_weight.py | 1 + lightllm/models/mixtral/model.py | 4 + .../layer_infer/transformer_layer_infer.py | 8 +- .../layer_weights/transformer_layer_weight.py | 6 + lightllm/models/qwen3_moe/model.py | 4 + lightllm/server/api_cli.py | 6 + lightllm/server/api_lightllm.py | 5 + lightllm/server/core/objs/req.py | 65 +++++ lightllm/server/core/objs/shm_array.py | 13 + lightllm/server/core/objs/start_args_type.py | 2 + lightllm/server/httpserver/manager.py | 18 ++ .../server/router/model_infer/infer_batch.py | 19 ++ .../model_infer/mode_backend/base_backend.py | 13 + .../mode_backend/chunked_prefill/impl.py | 4 + .../mode_backend/diverse_backend/impl.py | 2 +- .../mode_backend/dp_backend/impl.py | 13 +- test/test_api/test_r3.py | 99 ++++++++ unit_tests/__init__.py | 0 unit_tests/common/__init__.py | 0 unit_tests/common/basemodel/__init__.py | 0 .../basemodel/test_routing_capture_manager.py | 219 +++++++++++++++++ 36 files changed, 781 insertions(+), 81 deletions(-) create mode 100644 lightllm/common/basemodel/routing_manager.py delete mode 100644 lightllm/models/mixtral/layer_infer/_custom_ops.py create mode 100644 test/test_api/test_r3.py create mode 100644 unit_tests/__init__.py create mode 100644 unit_tests/common/__init__.py create mode 100644 unit_tests/common/basemodel/__init__.py create mode 100644 unit_tests/common/basemodel/test_routing_capture_manager.py diff --git a/.gitignore b/.gitignore index 63408699f..3fb49db8b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .vscode tmp/ requirements-musa.txt +CLAUDE.md diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 77d6d40e9..3dc888b6a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -33,6 +33,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] 
= None, + moe_layer_index: int = 0, ) -> None: super().__init__(data_type=data_type) self.w1_weight_name = gate_proj_name @@ -50,6 +51,7 @@ def __init__( self.enable_ep_moe = get_env_start_args().enable_ep_moe self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts + self.moe_layer_index = moe_layer_index self._init_config(network_config) self._init_redundancy_expert_params() self._init_parallel_params() @@ -130,6 +132,7 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,6 +148,8 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + moe_layer_index=self.moe_layer_index, + microbatch_index=microbatch_index, ) def low_latency_dispatch( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 6ed0cef0b..4ca1605be 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -8,6 +8,7 @@ from lightllm.common.quantization import Quantcfg from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel import routing_manager as _routing_mgr logger = init_logger(__name__) @@ -46,6 +47,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + moe_layer_index: int = 0, ) -> None: network_config["norm_topk_prob"] = None super().__init__( @@ -62,6 +64,7 @@ def __init__( num_fused_shared_experts=num_fused_shared_experts, layer_num=layer_num, network_config=network_config, + moe_layer_index=moe_layer_index, ) self.hidden_size = network_config["hidden_size"] @@ -144,10 +147,15 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._router(router_logits, top_k) + # Rollout router replay + if _routing_mgr.g_routing_capture_manager is not None: + _routing_mgr.g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 00587ac18..1c93cb13d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -62,5 +62,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index 8bcdb4bf9..1e81b226e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -3,6 +3,7 @@ from lightllm.common.quantization.no_quant import WeightPack from lightllm.common.quantization.quantize_method import QuantizationMethod from .base_impl import FuseMoeBaseImpl +from lightllm.common.basemodel import routing_manager as _routing_mgr class FuseMoeTriton(FuseMoeBaseImpl): @@ -124,6 +125,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -136,6 +139,10 @@ def __call__( num_expert_group=num_expert_group, scoring_func=scoring_func, ) + + if _routing_mgr.g_routing_capture_manager is not None and moe_layer_index is not None: + _routing_mgr.g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index) + output = self._fused_experts( input_tensor=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py new file mode 100644 index 000000000..9b8c09d8c --- /dev/null +++ b/lightllm/common/basemodel/routing_manager.py @@ -0,0 +1,224 @@ +import atexit +import torch +import numpy as np +from multiprocessing import shared_memory +from typing import Optional +from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_current_rank_in_dp +from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.shm_utils import create_or_link_shm + +logger = init_logger(__name__) + + +def routing_dtype_id_to_np(dtype_id: int): + if dtype_id == 1: + return np.int8 + elif dtype_id == 2: + return np.int16 + return np.int32 + + +def get_routing_config_shm() -> SharedArray: + service_name = get_unique_server_name() + return SharedArray(f"{service_name}_routing_config", shape=(4,), dtype=np.int32) + + +class RoutingCaptureManager: + def __init__( + self, + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, + ): + self.num_moe_layers = num_moe_layers + self.topk = topk + self.num_experts = num_experts + self.kv_cache_size = kv_cache_size + + self.dtype = torch.int8 if num_experts <= 127 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.int8 else 2 + + # Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory. + # Written after forward() via flush_to_routing_buffer(), read on request finish. + routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes + self.routing_buffer = torch.zeros( + (num_moe_layers, kv_cache_size, topk), + dtype=self.dtype, + device="cpu", + ) + + # Capture buffers: simple contiguous tensors written to during forward(). 
+ capture_buf_size = max_capture_tokens * num_moe_layers * topk * dtype_bytes + self._capture_buffer = [ + torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2) + ] + + dtype_name = "int8" if self.dtype == torch.int8 else "int16" + logger.info( + f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " + f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, " + f"capture_buffer={capture_buf_size / 1024 / 1024:.2f}MB x2, dtype={dtype_name}" + ) + + @property + def np_dtype(self): + return np.int8 if self.dtype == torch.int8 else np.int16 + + @property + def dtype_id(self) -> int: + return 1 if self.dtype == torch.int8 else 2 + + def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: + num_tokens = topk_ids.shape[0] + self._capture_buffer[microbatch_index][:num_tokens, moe_layer_index, :] = topk_ids.to(self.dtype) + + def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None: + buf = self._capture_buffer[microbatch_index][:num_tokens] # (num_tokens, num_moe_layers, topk) + buf_t = buf.permute(1, 0, 2).cpu() + self.routing_buffer[:, mem_indexes[:num_tokens].cpu(), :] = buf_t + + def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray: + cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes + return self.routing_buffer[:, cpu_indexes, :].numpy() + + +g_routing_capture_manager: Optional[RoutingCaptureManager] = None + + +def create_routing_capture_manager( + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, +) -> None: + global g_routing_capture_manager + assert g_routing_capture_manager is None, "RoutingCaptureManager already exists" + g_routing_capture_manager = RoutingCaptureManager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=kv_cache_size, + max_capture_tokens=max_capture_tokens, + ) + + +def preallocate_routing_shm_pool(max_req_num: int, num_moe_layers: int, max_tokens: int, topk: int, np_dtype) -> None: + """Pre-allocate POSIX SHM segments for all request slots. + + Each segment is sized for the maximum possible routing data so it can be + reused across requests without create/destroy overhead. + """ + dtype_bytes = np.dtype(np_dtype).itemsize + segment_size = num_moe_layers * max_tokens * topk * dtype_bytes + service_name = get_unique_server_name() + + for i in range(max_req_num): + name = f"{service_name}_shm_routing_{i}" + shm = create_or_link_shm(name, segment_size, auto_cleanup=True) + shm.close() # close handle; SHM persists in /dev/shm + + logger.info( + f"Pre-allocated {max_req_num} routing SHM segments, " + f"each {segment_size / 1024:.1f} KB (total {max_req_num * segment_size / 1024 / 1024:.1f} MB)" + ) + + +def cleanup_routing_shm_pool() -> None: + """Unlink all pre-allocated routing SHM segments. 
Called at server shutdown.""" + try: + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + except Exception: + return + + service_name = get_unique_server_name() + + for i in range(args.running_max_req_size): + name = f"{service_name}_shm_routing_{i}" + try: + shm = shared_memory.SharedMemory(name=name) + shm.close() + shm.unlink() + except Exception: + pass + + config_name = f"{service_name}_routing_config" + try: + shm = shared_memory.SharedMemory(name=config_name) + shm.close() + shm.unlink() + except Exception: + pass + + +def init_routing_capture(model, num_moe_layers: int) -> None: + dp_rank = get_current_rank_in_dp() + logger.info(f"init_routing_capture called: num_moe_layers={num_moe_layers}, dp_rank={dp_rank}") + if dp_rank != 0: + logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}") + return + + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled." + ) + return + + num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0)) + topk = model.config.get("num_experts_per_tok", 0) + assert num_experts > 0 and topk > 0 + + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + + # Capture buffer must fit the max tokens in any single forward call. + # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size. + batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192 + max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=model.mem_manager.size + 1, + max_capture_tokens=max_capture_tokens, + ) + + mgr = g_routing_capture_manager + np_dtype = mgr.np_dtype + dtype_id = mgr.dtype_id + + max_req_total_len = args.max_req_total_len + + # Write config to cross-process SHM + shm = get_routing_config_shm() + shm.arr[0] = num_moe_layers + shm.arr[1] = topk + shm.arr[2] = dtype_id + shm.arr[3] = max_req_total_len + logger.info( + f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, " + f"dtype_id={dtype_id}, max_tokens={max_req_total_len}" + ) + + preallocate_routing_shm_pool( + max_req_num=args.running_max_req_size, + num_moe_layers=num_moe_layers, + max_tokens=max_req_total_len, + topk=topk, + np_dtype=np_dtype, + ) + + atexit.register(cleanup_routing_shm_pool) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 98cc7c229..97015f6b2 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -312,6 +312,7 @@ def _moe_ffn( use_grouped_topk=self.n_group, topk_group=self.topk_group, num_expert_group=self.n_group, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: @@ -339,6 +340,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None: diff --git 
a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f917..bd7203507 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -242,6 +242,9 @@ def _init_moe(self): # == 0 时,说明不存在融合共享专家,共享专家单独加载和进行推理。 if self.num_fused_shared_experts == 0: self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) + first_moe = self.network_config_["first_k_dense_replace"] + freq = self.network_config_.get("moe_layer_freq", 1) + moe_layer_index = (self.layer_num_ - first_moe) // freq self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -256,6 +259,7 @@ def _init_moe(self): num_fused_shared_experts=self.num_fused_shared_experts, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_ffn(self): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97..79bd32706 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -6,6 +6,7 @@ from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager @@ -49,6 +50,9 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index d80eefd16..e5672f821 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 7c8c30940..7278c62fe 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -55,6 +55,7 @@ def _init_moe(self): num_fused_shared_experts=0, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) def _init_weight_names(self): diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index 9e9561eb2..cff748933 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -2,6 +2,7 @@ from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import 
GptOssTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class @@ -21,6 +22,12 @@ class GptOssTpPartModel(LlamaTpPartModel): def __init__(self, kvargs): super().__init__(kvargs) + def _init_custom(self): + super()._init_custom() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) + def _init_att_backend(self): self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])( model=self diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index f86bd5f83..cc1dc2817 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -74,16 +74,19 @@ def _init_custom(self): rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: self._init_to_get_rotary() - if "rope_theta_hw" in self.config: - self._init_to_get_hw_rotary() - return - - if "rope_type" in rope_scaling: + elif "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) elif "type" in rope_scaling: scaling_type = rope_scaling["type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) else: raise ValueError(f"Unknown RoPE scaling format {rope_scaling}") + if "rope_theta_hw" in self.config: + self._init_to_get_hw_rotary() + super()._init_custom() + + def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): if scaling_type == "default" or "mrope_section" in rope_scaling: self._init_to_get_rotary() elif scaling_type == "yarn": @@ -98,9 +101,6 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - if "rope_theta_hw" in self.config: - self._init_to_get_hw_rotary() - return def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py deleted file mode 100644 index b0e27ac1d..000000000 --- a/lightllm/models/mixtral/layer_infer/_custom_ops.py +++ /dev/null @@ -1,46 +0,0 @@ -import functools -import json -import os -from typing import Any, Dict, Optional, Tuple - -import torch -import triton -import triton.language as tl -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -# Pytorch version -# Triton version in progress -def topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output, - topk=2, -): - scores = torch.softmax(gating_output, dim=-1) - topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False) - return topk_weights, topk_ids - - -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - alloc_tensor_func=torch.empty, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - M, _ = hidden_states.shape - - topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device) - topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, 
device=hidden_states.device) - token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk) - del token_expert_indicies # Not used. Will be used in the future. - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 44e66cff2..a2968f5ab 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -1,9 +1,6 @@ -import os import torch -import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight @@ -19,25 +16,15 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(input.view(-1, self.embed_dim_)) - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.num_experts_per_tok, + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, renormalize=self.renormalize, - alloc_tensor_func=self.alloc_tensor, - ) - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - - return fused_experts_impl( - hidden_states=hidden_states, - w1=layer_weight.experts.w1[0], - w2=layer_weight.experts.w2[0], - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=False, - w1_scale=None, - w2_scale=None, - alloc_tensor_func=self.alloc_tensor, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) + return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index 51c62fd4c..d93cb5fb5 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -57,4 +57,5 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 3c2d7b4e8..35bf38de5 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -2,6 +2,7 @@ import numpy as np from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer 
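A note on the new moe_layer_index argument threaded into FusedMoeWeight throughout these patches: DeepSeek-V2-style models derive a dense MoE index from first_k_dense_replace and moe_layer_freq, while GPT-OSS and Mixtral pass layer_num_ through unchanged because every decoder layer there is sparse. A minimal sketch of that mapping follows; the helper name and the config values are illustrative only, not taken from any real checkpoint.

    def deepseek2_moe_layer_index(layer_num: int, network_config: dict) -> int:
        # Mirrors the expression added in the deepseek2 transformer_layer_weight hunk above.
        first_moe = network_config["first_k_dense_replace"]
        freq = network_config.get("moe_layer_freq", 1)
        return (layer_num - first_moe) // freq

    # Assumed example config: one leading dense layer, then MoE on every layer after it.
    cfg = {"first_k_dense_replace": 1, "moe_layer_freq": 1}
    assert deepseek2_moe_layer_index(1, cfg) == 0   # first sparse layer
    assert deepseek2_moe_layer_index(5, cfg) == 4
    # GPT-OSS and Mixtral simply pass layer_num_ as moe_layer_index.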
@@ -45,6 +46,9 @@ def _verify_params(self): def _init_custom(self): self._init_to_get_rotary() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) return def _init_mem_manager(self): diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 9eccddffc..af035e81b 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -93,8 +93,10 @@ def _tpsp_get_qkv( input = gather_input[0 : len(infer_state.input_ids), :] input = input.view(-1, self.embed_dim_) - q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input) + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) layer_weight.q_norm_weight_(q, eps=self.eps_) layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -130,6 +132,7 @@ def _moe_ffn( use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) @@ -150,6 +153,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) ep_output = ep_output.view(token_num, hidden_dim) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 13ba6cbe0..5a857fd09 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -52,6 +52,11 @@ def _init_moe(self): tp_rank=0, tp_world_size=1, ) + mlp_only = set(self.network_config_.get("mlp_only_layers", [])) + step = self.network_config_.get("decoder_sparse_step", 1) + moe_layer_index = sum( + 1 for i in range(self.layer_num_) if self.n_routed_experts > 0 and i not in mlp_only and (i + 1) % step == 0 + ) self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -65,6 +70,7 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_qkv(self): diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index 10a505127..2926a12b1 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -4,6 +4,7 @@ from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -26,3 +27,6 @@ def __init__(self, kvargs): def _init_custom(self): super()._init_custom() dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) diff 
--git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 73401f163..409460feb 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -644,4 +644,10 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: If the op is not implemented for the platform and the hardware support triton, it will use triton implementation.""", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) return parser diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f5..5abd90815 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -53,6 +53,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt_token_ids = None is_first_metadata = True input_usage = None + routed_experts_data = None async for sub_req_id, request_output, metadata, finish_status in results_generator: # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids @@ -78,6 +79,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + if "routed_experts" in metadata: + routed_experts_data = metadata["routed_experts"] n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] @@ -102,6 +105,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage + if routed_experts_data is not None: + ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 7df5ba74e..4a33b659b 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -1,6 +1,7 @@ import os import math import ctypes +import base64 import numpy as np import time from .sampling_params import SamplingParams @@ -13,6 +14,7 @@ from lightllm.utils.kv_cache_utils import compute_token_list_hash from typing import List, Any, Union from lightllm.utils.log_utils import init_logger +from lightllm.utils.shm_utils import create_or_link_shm logger = init_logger(__name__) @@ -125,6 +127,8 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), + # Number of tokens in routing data SHM, written by model worker, read by HTTP server. 
+ ("shm_routing_num_tokens", ctypes.c_int), ] def get_str(self): @@ -182,6 +186,7 @@ def init( self._mtp_step = get_env_start_args().mtp_step self.stop_str_matched = False self.stop_str_matched_token_index = -1 + self.shm_routing_num_tokens = 0 self.post_init() @@ -230,6 +235,66 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return + def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.int8): + """Link to a pre-allocated routing SHM and create a numpy view for the actual data shape.""" + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_moe_layers, num_tokens, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.link_shm_partial() + self.shm_routing_num_tokens = num_tokens + return + + def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=np.int8): + """Link to the pre-allocated routing SHM from the reader side (HTTP server).""" + if num_moe_layers == 0: + return + num_tokens = self.shm_routing_num_tokens + if num_tokens <= 0: + return + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_moe_layers, num_tokens, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.link_shm_partial() + return + + def get_routing_data(self): + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + return None + return self.shm_routing_data.arr + + def close_routing_data_shm_array(self): + """Detach from pre-allocated SHM without unlinking it.""" + if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: + self.shm_routing_data.detach_shm() + self.shm_routing_data = None + self.shm_routing_num_tokens = 0 + return + + def get_routing_metadata(self, num_moe_layers: int, topk: int, dtype_id: int = 1): + if num_moe_layers == 0 or topk == 0: + return None + if self.shm_routing_num_tokens <= 0: + return None + try: + from lightllm.common.basemodel.routing_manager import routing_dtype_id_to_np + + np_dtype = routing_dtype_id_to_np(dtype_id) + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + self.link_routing_data_shm_array(num_moe_layers, topk, np_dtype=np_dtype) + routing_data = self.get_routing_data() + if routing_data is None: + return None + return { + "shape": list(routing_data.shape), + "dtype": str(routing_data.dtype), + "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), + } + except Exception as e: + logger.warning(f"Failed to read routing data for req {self.request_id}: {e}") + return None + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py index c5ad512c6..1bf20535a 100644 --- a/lightllm/server/core/objs/shm_array.py +++ b/lightllm/server/core/objs/shm_array.py @@ -26,6 +26,19 @@ def link_shm(self): self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) return + def link_shm_partial(self): + """Link to an existing SHM that may be larger than the needed shape.""" + self.shm = create_or_link_shm(self.name, -1, force_mode="link") + assert self.shm.size >= self.dest_size, f"SHM {self.name} too small: need {self.dest_size}, got {self.shm.size}" + self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + + def detach_shm(self): + """Close handle 
without unlinking (SHM persists for reuse).""" + if self.shm is not None: + self.shm.close() + self.shm = None + self.arr = None + def close_shm(self): if self.shm is not None: self.shm.close() diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 2cb12ed89..4ac0a4dd2 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -184,4 +184,6 @@ class StartArgs: enable_torch_fallback: bool = field(default=False) enable_triton_fallback: bool = field(default=False) + enable_return_routed_experts: bool = field(default=False) + weight_version: str = "default" diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index b6a7b0d12..3ef778ca4 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -29,6 +29,7 @@ from lightllm.server.core.objs.shm_req_manager import ShmReqManager from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.common.basemodel.routing_manager import get_routing_config_shm from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient from lightllm.server.io_struct import ( @@ -139,6 +140,9 @@ def __init__( self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + # Cache routing config for MoE expert routing data extraction + self._routing_shm = get_routing_config_shm() if args.enable_return_routed_experts else None + self.is_pause = False self.is_pause_cond = asyncio.Condition() @@ -769,6 +773,11 @@ async def recycle_resource_loop(self): for req_status in release_req_status: self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) for req in req_status.group_req_objs.shm_req_objs: + if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None: + try: + req.close_routing_data_shm_array() + except Exception as e: + logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}") await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) @@ -868,6 +877,15 @@ async def _handle_recv_generate_request(self, recv_obj: GenerateReqMeta): else: finish_status = FinishStatus(req.finish_status.status) + if self._routing_shm is not None: + _num_moe = int(self._routing_shm.arr[0]) + _topk = int(self._routing_shm.arr[1]) + _dtype_id = int(self._routing_shm.arr[2]) + if _num_moe > 0: + routing_meta = req.get_routing_metadata(_num_moe, _topk, dtype_id=_dtype_id) + if routing_meta is not None: + metadata["routed_experts"] = routing_meta + token_list.append((req_id, text, metadata, finish_status)) else: break diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 2b35fad05..66aeb6e95 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -19,6 +19,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.common.basemodel import routing_manager 
as _routing_mgr logger = init_logger(__name__) @@ -113,6 +114,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs + def _extract_routing_data(self, req: "InferReq"): + if req.shm_req.shm_routing_num_tokens > 0: + return + mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] + mgr = _routing_mgr.g_routing_capture_manager + routing_data = mgr.extract_routing_data(mem_indexes) + req.shm_req.create_routing_data_shm_array(mgr.num_moe_layers, req.cur_kv_len, mgr.topk, np_dtype=mgr.np_dtype) + req.shm_req.shm_routing_data.arr[:] = routing_data + req.shm_req.shm_routing_data.detach_shm() + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) @@ -149,12 +160,18 @@ def _filter(self, finished_request_ids: List[int]): if len(finished_request_ids) == 0: return + need_routing_data = _routing_mgr.g_routing_capture_manager is not None + free_req_index = [] free_token_index = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() + + if need_routing_data: + self._extract_routing_data(req) + self.free_a_req_mem(free_token_index, req) free_req_index.append(req.req_idx) @@ -588,6 +605,8 @@ def handle( shm_req.shm_cur_output_len = self.output_len if finish_status.is_finished(): + if _routing_mgr.g_routing_capture_manager is not None: + g_infer_context._extract_routing_data(req_obj) shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 2d0e4b14b..ff2ea8c21 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -44,6 +44,7 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.common.basemodel import routing_manager as _routing_mgr from lightllm.utils.torch_memory_saver_utils import MemoryTag from .multi_level_kv_cache import MultiLevelKvCacheModule from lightllm.server.io_struct import ( @@ -996,6 +997,18 @@ def _sample_and_scatter_token( ) return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu + def _flush_routing_to_kv_buffer(self, mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: + """Scatter captured routing data from capture buffer to KV-indexed GPU buffer. + + Must be called AFTER model.forward() completes. mem_indexes should be the + original (unpadded) tensor — either CPU or CUDA. 
+ """ + if _routing_mgr.g_routing_capture_manager is not None and mem_indexes is not None: + if not mem_indexes.is_cuda: + mem_indexes = mem_indexes.cuda(non_blocking=True) + num_tokens = mem_indexes.shape[0] + _routing_mgr.g_routing_capture_manager.flush_to_kv_buffer(mem_indexes, num_tokens, microbatch_index) + def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index a8a5224eb..9f4443e48 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -109,6 +109,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -148,6 +149,7 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -186,6 +188,7 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -236,6 +239,7 @@ def decode_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5a179cb62..ebc55b7ef 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -40,8 +40,8 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq ) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) logits = model_output.logits batch_idx, run_reqs = self._diverse_copy( diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bb0e848e7..f01e5fe93 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -145,6 +145,7 @@ def prefill_normal( run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -188,6 +189,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -236,6 +238,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -305,6 +309,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -359,6 +365,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] req_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num] logits = model_output.logits[0:req_num, :] b_req_idx = model_input.b_req_idx[0:req_num] @@ -421,6 +428,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None if req_num > 0: logits = model_output.logits[0:req_num, :] @@ -629,6 +637,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits req_num0, req_num1 = len(run_reqs0), len(run_reqs1) @@ -726,8 +736,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = 
model_input1.b_mtp_index with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits run_reqs = run_reqs0 + run_reqs1 diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py new file mode 100644 index 000000000..00f34c489 --- /dev/null +++ b/test/test_api/test_r3.py @@ -0,0 +1,99 @@ +import sys +import argparse +import requests +import base64 +import numpy as np + + +def test_routing_export(url: str = "http://localhost:8000"): + print(f"Testing routing export at {url}") + print("-" * 50) + + try: + response = requests.post( + f"{url}/generate", + json={ + "inputs": "What is the capital of France? What is the capital of France?", + "parameters": { + "max_new_tokens": 50, + "return_routed_experts": True, + "repetition_penalty": 1.0, + }, + }, + timeout=60, + ) + except requests.exceptions.ConnectionError: + print(f"ERROR: Cannot connect to server at {url}") + print("Make sure the LightLLM server is running with --enable_return_routed_experts") + return False + except requests.exceptions.Timeout: + print("ERROR: Request timed out") + return False + + print(f"Status: {response.status_code}") + + if response.status_code != 200: + print(f"ERROR: Request failed with status {response.status_code}") + print(f"Response: {response.text}") + return False + + res = response.json() + print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...") + + if "routed_experts" not in res or not res["routed_experts"]: + print("\nWARNING: No routed_experts in response.") + print("This could mean:") + print(" - The model is not a MoE model") + print(" - The server was not started with --enable_return_routed_experts") + print(" - The routing capture manager was not initialized") + return False + + routing_info = res["routed_experts"] + shape = routing_info["shape"] + dtype_str = routing_info["dtype"] + dtype = np.dtype(dtype_str) + data = base64.b64decode(routing_info["data"]) + routing_array = np.frombuffer(data, dtype=dtype).reshape(shape) + + print(f"\n{'=' * 50}") + print("ROUTING CAPTURE SUCCESS!") + print(f"{'=' * 50}") + print(f"Shape: {shape}") + print(f"Dtype: {dtype}") + print(f"Num MoE layers: {shape[0]}") + print(f"Num tokens: {shape[1]}") + print(f"Top-K: {shape[2]}") + + # Verify dtype is int8 (for models with ≤127 experts) or int16 + if dtype_str not in ("int8", "int16"): + print(f"\nERROR: Expected dtype int8 or int16, got {dtype_str}") + print("This suggests dtype optimization is not working correctly.") + return False + print(f"\nDtype check PASSED: {dtype_str} (compact representation)") + + # Compute payload size savings + int32_size = np.prod(shape) * 4 + actual_size = len(data) + savings = (1 - actual_size / int32_size) * 100 + print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)") + + print(f"\nSample routing (first layer, first 5 tokens):") + num_tokens_to_show = shape[1] + for i in range(num_tokens_to_show): + print(f" Token {i}: experts {routing_array[0, i, :].tolist()}") + + if np.all(routing_array == 0): + print("\nWARNING: All routing data is zeros. 
Capture may not be working correctly.") + return False + + print("\nTest PASSED!") + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test R3 routing export feature") + parser.add_argument("--url", default="http://localhost:8000", help="Server URL") + args = parser.parse_args() + + success = test_routing_export(args.url) + sys.exit(0 if success else 1) diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py new file mode 100644 index 000000000..dcc010b37 --- /dev/null +++ b/unit_tests/common/basemodel/test_routing_capture_manager.py @@ -0,0 +1,219 @@ +import torch +import numpy as np + + +class TestRoutingCaptureManager: + def test_capture_and_extract_basic(self): + """Test the core pipeline: capture → flush_to_kv_buffer → extract_from_gpu.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + kv_cache_size=1024, + max_capture_tokens=64, + ) + + # Simulate a batch of 10 tokens at KV-cache positions [100..109] + mem_indexes = torch.arange(100, 110, device="cuda") + + # Capture routing for each MoE layer (writes to capture buffer) + for layer_idx in range(4): + topk_ids = torch.randint(0, 64, (10, 8), device="cuda") + manager.capture(moe_layer_index=layer_idx, topk_ids=topk_ids, microbatch_index=0) + + # Flush from capture buffer to KV-indexed gpu_kv_buffer + manager.flush_to_kv_buffer(mem_indexes, num_tokens=10, microbatch_index=0) + + # Extract for those same KV-cache positions + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (4, 10, 8) + assert result.dtype == np.int8 + + def test_capture_writes_to_correct_kv_positions(self): + """Verify that captured data lands in the right KV-cache positions after flush.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + kv_cache_size=256, + max_capture_tokens=16, + ) + + # Use non-contiguous mem_indexes to simulate real KV-cache + mem_indexes = torch.tensor([10, 50, 200], device="cuda") + + topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + topk_ids_layer1 = topk_ids + 20 + manager.capture(moe_layer_index=1, topk_ids=topk_ids_layer1, microbatch_index=0) + + # Flush to KV positions + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + + # Extract and verify + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (2, 3, 4) + np.testing.assert_array_equal(result[0], topk_ids.cpu().numpy()) + np.testing.assert_array_equal(result[1], topk_ids_layer1.cpu().numpy()) + + def test_microbatch_isolation(self): + """Two microbatches writing to different KV positions don't interfere.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=32, + kv_cache_size=256, + 
max_capture_tokens=16, + ) + + # Microbatch 0: positions [10, 11] + mem0 = torch.tensor([10, 11], device="cuda") + ids_0 = torch.ones((2, 4), dtype=torch.int64, device="cuda") + manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0) + + # Microbatch 1: positions [20, 21] + mem1 = torch.tensor([20, 21], device="cuda") + ids_1 = torch.ones((2, 4), dtype=torch.int64, device="cuda") * 2 + manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1) + + # Flush each microbatch to different KV positions + manager.flush_to_kv_buffer(mem0, num_tokens=2, microbatch_index=0) + manager.flush_to_kv_buffer(mem1, num_tokens=2, microbatch_index=1) + + # Extract microbatch 0 + result0 = manager.extract_from_gpu(mem0) + assert result0.shape == (1, 2, 4) + assert result0[0, 0, 0] == 1 + + # Extract microbatch 1 + result1 = manager.extract_from_gpu(mem1) + assert result1[0, 0, 0] == 2 + + def test_dtype_selection_int8(self): + """Models with ≤127 experts use int8.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=64, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int8 + assert manager.np_dtype == np.int8 + assert manager.dtype_id == 1 + + def test_dtype_selection_int16(self): + """Models with >127 experts use int16.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=256, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int16 + assert manager.np_dtype == np.int16 + assert manager.dtype_id == 2 + + def test_extract_preserves_values(self): + """Extracted values exactly match what was captured, no dtype truncation.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=64, + kv_cache_size=64, + max_capture_tokens=16, + ) + + mem_indexes = torch.tensor([0, 1, 2], device="cuda") + + topk_ids = torch.tensor([[10, 20, 30, 40], [50, 60, 63, 1], [0, 5, 127, 3]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + # Flush then extract + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + result = manager.extract_from_gpu(mem_indexes) + expected = topk_ids.cpu().numpy().astype(np.int8) + np.testing.assert_array_equal(result[0], expected) + + def test_gpu_kv_buffer_shape(self): + """Buffer shape is (num_moe_layers, kv_cache_size, topk).""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + # 127 experts fits in int8 (max value 127) + manager = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=127, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager.gpu_kv_buffer.shape == (48, 2048, 8) + assert manager.gpu_kv_buffer.dtype == torch.int8 + assert manager.gpu_kv_buffer.device.type == "cuda" + + # 128 experts requires int16 + manager2 = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=128, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager2.gpu_kv_buffer.dtype == torch.int16 + + def test_partial_token_capture(self): + """capture() only writes num_tokens rows to the buffer.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=32, + 
kv_cache_size=128, + max_capture_tokens=16, + ) + + # Capture only 3 tokens, flush to 5 KV positions (first 3 get data) + mem_indexes = torch.tensor([10, 11, 12, 13, 14], device="cuda") + + topk_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], device="cuda") # only 3 tokens + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + # Flush only the 3 captured tokens + manager.flush_to_kv_buffer(mem_indexes[:3], num_tokens=3, microbatch_index=0) + + # Positions 10-12 should have data, 13-14 should be zeros (from init) + result_written = manager.extract_from_gpu(mem_indexes[:3]) + np.testing.assert_array_equal(result_written[0], topk_ids.cpu().numpy().astype(np.int8)) + + result_unwritten = manager.extract_from_gpu(mem_indexes[3:]) + np.testing.assert_array_equal(result_unwritten[0], np.zeros((2, 2), dtype=np.int8)) + + def test_capture_buffer_shape(self): + """Capture buffer has correct shape (max_tokens, num_moe_layers, topk).""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + kv_cache_size=1024, + max_capture_tokens=256, + ) + assert manager._capture_buffer[0].shape == (256, 4, 8) + assert manager._capture_buffer[1].shape == (256, 4, 8) + assert manager._capture_buffer[0].dtype == torch.int8 From 3569d53a6acad7d886638679f56b355f2fc5cddf Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 9 Feb 2026 11:20:43 +0000 Subject: [PATCH 63/71] fix --- lightllm/server/router/model_infer/mode_backend/base_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ff2ea8c21..70b0ec9eb 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -1007,7 +1007,7 @@ def _flush_routing_to_kv_buffer(self, mem_indexes: torch.Tensor, microbatch_inde if not mem_indexes.is_cuda: mem_indexes = mem_indexes.cuda(non_blocking=True) num_tokens = mem_indexes.shape[0] - _routing_mgr.g_routing_capture_manager.flush_to_kv_buffer(mem_indexes, num_tokens, microbatch_index) + _routing_mgr.g_routing_capture_manager.flush_to_routing_buffer(mem_indexes, num_tokens, microbatch_index) def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] From fe54253e8de8aa6b0b4bacc8a9ce2b0240f5d255 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 9 Feb 2026 11:29:05 +0000 Subject: [PATCH 64/71] add node-id for env_utils --- lightllm/utils/envs_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 3a0e28bcb..59315108a 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -11,15 +11,18 @@ def set_unique_server_name(args): if args.run_mode == "pd_master": - os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master" + os.environ[f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}"] = str(args.port) + "_pd_master" else: - os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank) + os.environ[f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}"] = ( + str(args.nccl_port) + "_" + str(args.node_rank) + ) return @lru_cache(maxsize=None) def get_unique_server_name(): - service_uni_name = 
os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID") + args = get_env_start_args() + service_uni_name = os.getenv(f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}") return service_uni_name @@ -33,7 +36,7 @@ def set_env_start_args(args): set_cuda_arch(args) if not isinstance(args, dict): args = vars(args) - os.environ["LIGHTLLM_START_ARGS"] = json.dumps(args) + os.environ[f"LIGHTLLM_START_ARGS_{args.pd_node_id}"] = json.dumps(args) return @@ -41,7 +44,8 @@ def set_env_start_args(args): def get_env_start_args(): from lightllm.server.core.objs.start_args_type import StartArgs - start_args: StartArgs = json.loads(os.environ["LIGHTLLM_START_ARGS"]) + args = get_env_start_args() + start_args: StartArgs = json.loads(os.environ[f"LIGHTLLM_START_ARGS_{args.pd_node_id}"]) start_args: StartArgs = EasyDict(start_args) return start_args From 8eead2b14c492d928c6488560b7e07765ea2b516 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 9 Feb 2026 11:58:52 +0000 Subject: [PATCH 65/71] Revert "add node-id for env_utils" This reverts commit fe54253e8de8aa6b0b4bacc8a9ce2b0240f5d255. --- lightllm/utils/envs_utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 59315108a..3a0e28bcb 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -11,18 +11,15 @@ def set_unique_server_name(args): if args.run_mode == "pd_master": - os.environ[f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}"] = str(args.port) + "_pd_master" + os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master" else: - os.environ[f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}"] = ( - str(args.nccl_port) + "_" + str(args.node_rank) - ) + os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank) return @lru_cache(maxsize=None) def get_unique_server_name(): - args = get_env_start_args() - service_uni_name = os.getenv(f"LIGHTLLM_UNIQUE_SERVICE_NAME_ID_{args.pd_node_id}") + service_uni_name = os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID") return service_uni_name @@ -36,7 +33,7 @@ def set_env_start_args(args): set_cuda_arch(args) if not isinstance(args, dict): args = vars(args) - os.environ[f"LIGHTLLM_START_ARGS_{args.pd_node_id}"] = json.dumps(args) + os.environ["LIGHTLLM_START_ARGS"] = json.dumps(args) return @@ -44,8 +41,7 @@ def set_env_start_args(args): def get_env_start_args(): from lightllm.server.core.objs.start_args_type import StartArgs - args = get_env_start_args() - start_args: StartArgs = json.loads(os.environ[f"LIGHTLLM_START_ARGS_{args.pd_node_id}"]) + start_args: StartArgs = json.loads(os.environ["LIGHTLLM_START_ARGS"]) start_args: StartArgs = EasyDict(start_args) return start_args From 27f9e87d5edf971c4ffe9f7f7b969b672fd07684 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 9 Feb 2026 11:58:58 +0000 Subject: [PATCH 66/71] Revert "port random generate" This reverts commit 02078ad1ded3babcd1b3c8d152e83ee1238ba7d9. 
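For completeness on the R3 payload that the earlier patches in this series attach to /generate responses: routed_experts is returned as base64-encoded raw ndarray bytes plus shape and dtype metadata, exactly as consumed in test/test_api/test_r3.py. A client-side decode sketch, assuming res is the parsed JSON response from a MoE model launched with --enable_return_routed_experts (the function and variable names here are illustrative):

    import base64
    import numpy as np

    def decode_routed_experts(res: dict) -> np.ndarray:
        # res["routed_experts"] carries {"shape", "dtype", "data"} as produced by
        # Req.get_routing_metadata(); "data" is the raw buffer, base64-encoded.
        info = res["routed_experts"]
        raw = base64.b64decode(info["data"])
        # Reshape to the server-reported shape; no axis-order assumptions needed client-side.
        return np.frombuffer(raw, dtype=np.dtype(info["dtype"])).reshape(info["shape"])

Because the dtype string travels with the payload, the later switch from int8 to uint8 in the "fix r3" commit requires no client-side changes.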
--- lightllm/utils/net_utils.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index 51ec443d1..486414e88 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -2,6 +2,7 @@ import subprocess import ipaddress import random +import portpicker from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -12,36 +13,24 @@ def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000) used_nccl_ports = [] port_list = [] - locked_sockets = [] - used_set = set(used_nccl_ports) - max_port = 65535 max_attempts = num * 50 # Allow more attempts to find ports in range for _ in range(max_attempts): if len(port_list) >= num: break - # 在 [from_port_num, 65535] 范围内随机选端口,避免多进程同时启动时分配到相同端口 - port = random.randint(from_port_num, max_port) - if port in used_set: - continue - - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: - sock.bind(("", port)) - port_list.append(port) - used_set.add(port) - locked_sockets.append(sock) - logger.debug(f"Allocated and locked port: {port}") - - except OSError as e: - sock.close() - logger.warning(f"Failed to bind port: {e}") - continue + port = portpicker.pick_unused_port() - for sock in locked_sockets: - sock.close() + if port >= from_port_num and port not in used_nccl_ports: + port_list.append(port) + logger.debug(f"Allocated port: {port}") + else: + logger.debug(f"Port {port} is out of range or in used_nccl_ports, skipping") + + except Exception as e: + logger.warning(f"Failed to allocate port: {e}") + continue if len(port_list) < num: logger.error(f"Failed to allocate {num} ports, only got {len(port_list)}") From 6fa8f74e473632f598da4ec96b1d841a374dd0b2 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 9 Feb 2026 13:03:14 +0000 Subject: [PATCH 67/71] add assert none --- lightllm/utils/envs_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 3a0e28bcb..03816d3ab 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -13,6 +13,7 @@ def set_unique_server_name(args): if args.run_mode == "pd_master": os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master" else: + assert str(args.nccl_port) != "None", "nccl_port is not set" os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank) return From bf83078ae3e29547ccc9d823c6864102f119ed96 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 10 Feb 2026 05:00:10 +0000 Subject: [PATCH 68/71] set_unique_server_name --- lightllm/server/api_start.py | 3 +-- lightllm/utils/envs_utils.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index afe199d04..58dac941b 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -91,6 +91,7 @@ def _set_envs_and_config(args: StartArgs): def _launch_subprocesses(args: StartArgs): _set_envs_and_config(args) + set_unique_server_name(args) if not args.disable_shm_warning: check_recommended_shm_size(args) @@ -290,8 +291,6 @@ def _launch_subprocesses(args: StartArgs): args.pd_p_allowed_port_min = 20000 args.pd_p_allowed_port_max = 30000 - set_unique_server_name(args) - # p d 分离模式下,decode节点的调度间隙是0 if args.run_mode == "decode": args.router_max_wait_tokens = 0 diff 
--git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 03816d3ab..a702a465b 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -21,6 +21,8 @@ def set_unique_server_name(args): @lru_cache(maxsize=None) def get_unique_server_name(): service_uni_name = os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID") + assert "None" not in service_uni_name, "service_uni_name is not set" + return service_uni_name From 3eab5a746162ab6d50172df61791f614848d309d Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 10 Feb 2026 05:45:48 +0000 Subject: [PATCH 69/71] fix return_routed_experts --- lightllm/common/basemodel/routing_manager.py | 12 +++--- lightllm/server/api_lightllm.py | 6 ++- lightllm/server/core/objs/sampling_params.py | 39 +++++++++++--------- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py index 9b8c09d8c..01caa3666 100644 --- a/lightllm/common/basemodel/routing_manager.py +++ b/lightllm/common/basemodel/routing_manager.py @@ -14,7 +14,7 @@ def routing_dtype_id_to_np(dtype_id: int): if dtype_id == 1: - return np.int8 + return np.uint8 elif dtype_id == 2: return np.int16 return np.int32 @@ -39,8 +39,8 @@ def __init__( self.num_experts = num_experts self.kv_cache_size = kv_cache_size - self.dtype = torch.int8 if num_experts <= 127 else torch.int16 - dtype_bytes = 1 if self.dtype == torch.int8 else 2 + self.dtype = torch.uint8 if num_experts <= 255 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.uint8 else 2 # Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory. # Written after forward() via flush_to_routing_buffer(), read on request finish. @@ -57,7 +57,7 @@ def __init__( torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2) ] - dtype_name = "int8" if self.dtype == torch.int8 else "int16" + dtype_name = "uint8" if self.dtype == torch.uint8 else "int16" logger.info( f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, " @@ -66,11 +66,11 @@ def __init__( @property def np_dtype(self): - return np.int8 if self.dtype == torch.int8 else np.int16 + return np.uint8 if self.dtype == torch.uint8 else np.int16 @property def dtype_id(self) -> int: - return 1 if self.dtype == torch.int8 else 2 + return 1 if self.dtype == torch.uint8 else 2 def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: num_tokens = topk_ids.shape[0] diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 5abd90815..d15bec648 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] return_details = sample_params_dict.pop("return_details", False) + return_routed_experts = sample_params_dict.pop( + "return_routed_experts", httpserver_manager.args.enable_return_routed_experts + ) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() @@ -105,7 +108,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage - if routed_experts_data is 
not None: + if return_routed_experts and routed_experts_data is not None: ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) @@ -117,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] _ = sample_params_dict.pop("return_details", False) + _ = sample_params_dict.pop("return_routed_experts", None) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 93447830b..1c0642862 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -357,23 +357,28 @@ class SamplingParams(ctypes.Structure): def init(self, tokenizer, **kwargs): super().__init__() - self.best_of = kwargs.get("best_of", 1) - self.n = kwargs.get("n", self.best_of) - self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) - self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) - self.temperature = kwargs.get("temperature", SamplingParams._temperature) - self.top_p = kwargs.get("top_p", SamplingParams._top_p) - self.top_k = kwargs.get("top_k", SamplingParams._top_k) - self.ignore_eos = kwargs.get("ignore_eos", False) - self.min_pixels = kwargs.get("min_pixels", -1) - self.max_pixels = kwargs.get("max_pixels", -1) - self.max_new_tokens = kwargs.get("max_new_tokens", 16) - self.min_new_tokens = kwargs.get("min_new_tokens", 1) - self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) - self.group_request_id = kwargs.get("group_request_id", -1) - self.suggested_dp_index = kwargs.get("suggested_dp_index", -1) + + def _get(key, default): + v = kwargs.get(key) + return v if v is not None else default + + self.best_of = _get("best_of", 1) + self.n = _get("n", self.best_of) + self.do_sample = _get("do_sample", SamplingParams._do_sample) + self.presence_penalty = _get("presence_penalty", SamplingParams._presence_penalty) + self.frequency_penalty = _get("frequency_penalty", SamplingParams._frequency_penalty) + self.repetition_penalty = _get("repetition_penalty", SamplingParams._repetition_penalty) + self.temperature = _get("temperature", SamplingParams._temperature) + self.top_p = _get("top_p", SamplingParams._top_p) + self.top_k = _get("top_k", SamplingParams._top_k) + self.ignore_eos = _get("ignore_eos", False) + self.min_pixels = _get("min_pixels", -1) + self.max_pixels = _get("max_pixels", -1) + self.max_new_tokens = _get("max_new_tokens", 16) + self.min_new_tokens = _get("min_new_tokens", 1) + self.input_penalty = _get("input_penalty", DEFAULT_INPUT_PENALTY) + self.group_request_id = _get("group_request_id", -1) + self.suggested_dp_index = _get("suggested_dp_index", -1) self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) From 14cfc9511715e6f72b56d17279f8a41e5fb67ec6 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 10 Feb 2026 11:04:06 +0000 Subject: [PATCH 70/71] fix r3 --- lightllm/common/basemodel/routing_manager.py | 41 ++----------------- 
.../server/core/objs/py_sampling_params.py | 19 +++++---- lightllm/server/core/objs/req.py | 19 +++++---- lightllm/server/core/objs/sampling_params.py | 19 +++++---- test/test_api/test_r3.py | 19 +++------ 5 files changed, 45 insertions(+), 72 deletions(-) diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py index 01caa3666..77b611130 100644 --- a/lightllm/common/basemodel/routing_manager.py +++ b/lightllm/common/basemodel/routing_manager.py @@ -7,7 +7,6 @@ from lightllm.utils.dist_utils import get_current_rank_in_dp from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.utils.shm_utils import create_or_link_shm logger = init_logger(__name__) @@ -42,11 +41,11 @@ def __init__( self.dtype = torch.uint8 if num_experts <= 255 else torch.int16 dtype_bytes = 1 if self.dtype == torch.uint8 else 2 - # Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory. + # Shape: (kv_cache_size, num_moe_layers, topk) — on CPU to save GPU memory. # Written after forward() via flush_to_routing_buffer(), read on request finish. routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes self.routing_buffer = torch.zeros( - (num_moe_layers, kv_cache_size, topk), + (kv_cache_size, num_moe_layers, topk), dtype=self.dtype, device="cpu", ) @@ -78,12 +77,11 @@ def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None: buf = self._capture_buffer[microbatch_index][:num_tokens] # (num_tokens, num_moe_layers, topk) - buf_t = buf.permute(1, 0, 2).cpu() - self.routing_buffer[:, mem_indexes[:num_tokens].cpu(), :] = buf_t + self.routing_buffer[mem_indexes[:num_tokens].cpu(), :, :] = buf.cpu() def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray: cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes - return self.routing_buffer[:, cpu_indexes, :].numpy() + return self.routing_buffer[cpu_indexes, :, :].numpy() g_routing_capture_manager: Optional[RoutingCaptureManager] = None @@ -107,27 +105,6 @@ def create_routing_capture_manager( ) -def preallocate_routing_shm_pool(max_req_num: int, num_moe_layers: int, max_tokens: int, topk: int, np_dtype) -> None: - """Pre-allocate POSIX SHM segments for all request slots. - - Each segment is sized for the maximum possible routing data so it can be - reused across requests without create/destroy overhead. - """ - dtype_bytes = np.dtype(np_dtype).itemsize - segment_size = num_moe_layers * max_tokens * topk * dtype_bytes - service_name = get_unique_server_name() - - for i in range(max_req_num): - name = f"{service_name}_shm_routing_{i}" - shm = create_or_link_shm(name, segment_size, auto_cleanup=True) - shm.close() # close handle; SHM persists in /dev/shm - - logger.info( - f"Pre-allocated {max_req_num} routing SHM segments, " - f"each {segment_size / 1024:.1f} KB (total {max_req_num * segment_size / 1024 / 1024:.1f} MB)" - ) - - def cleanup_routing_shm_pool() -> None: """Unlink all pre-allocated routing SHM segments. 
Called at server shutdown.""" try: @@ -197,7 +174,6 @@ def init_routing_capture(model, num_moe_layers: int) -> None: ) mgr = g_routing_capture_manager - np_dtype = mgr.np_dtype dtype_id = mgr.dtype_id max_req_total_len = args.max_req_total_len @@ -212,13 +188,4 @@ def init_routing_capture(model, num_moe_layers: int) -> None: f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, " f"dtype_id={dtype_id}, max_tokens={max_req_total_len}" ) - - preallocate_routing_shm_pool( - max_req_num=args.running_max_req_size, - num_moe_layers=num_moe_layers, - max_tokens=max_req_total_len, - topk=topk, - np_dtype=np_dtype, - ) - atexit.register(cleanup_routing_shm_pool) diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 9194a235d..08921317e 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -112,13 +112,18 @@ def __init__( def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) cls._stop_sequences = generation_cfg.get("stop", None) except: pass diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 4a33b659b..5c7e56843 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -236,17 +236,20 @@ def link_logprobs_shm_array(self): return def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.int8): - """Link to a pre-allocated routing SHM and create a numpy view for the actual data shape.""" + """Create routing SHM at actual size (on-demand, not pre-allocated). + + Uses smart mode: links if same-sized SHM exists, otherwise creates new. 
+ """ service_uni_name = get_unique_server_name() name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" - shape = (num_moe_layers, num_tokens, topk) + shape = (num_tokens, num_moe_layers, topk) self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) - self.shm_routing_data.link_shm_partial() + self.shm_routing_data.create_shm() self.shm_routing_num_tokens = num_tokens return def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=np.int8): - """Link to the pre-allocated routing SHM from the reader side (HTTP server).""" + """Link to routing SHM from the reader side (HTTP server).""" if num_moe_layers == 0: return num_tokens = self.shm_routing_num_tokens @@ -254,9 +257,9 @@ def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=n return service_uni_name = get_unique_server_name() name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" - shape = (num_moe_layers, num_tokens, topk) + shape = (num_tokens, num_moe_layers, topk) self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) - self.shm_routing_data.link_shm_partial() + self.shm_routing_data.link_shm() return def get_routing_data(self): @@ -265,9 +268,9 @@ def get_routing_data(self): return self.shm_routing_data.arr def close_routing_data_shm_array(self): - """Detach from pre-allocated SHM without unlinking it.""" + """Close and unlink routing SHM (on-demand, no longer pooled).""" if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: - self.shm_routing_data.detach_shm() + self.shm_routing_data.close_shm() self.shm_routing_data = None self.shm_routing_num_tokens = 0 return diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 1c0642862..31e2fbefe 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -443,13 +443,18 @@ def _get(key, default): def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) except: pass diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py index 00f34c489..85c4e44ef 100644 --- a/test/test_api/test_r3.py +++ b/test/test_api/test_r3.py @@ -16,8 +16,8 @@ def test_routing_export(url: str = "http://localhost:8000"): "inputs": "What is the capital of France? 
What is the capital of France?", "parameters": { "max_new_tokens": 50, - "return_routed_experts": True, - "repetition_penalty": 1.0, + # "return_routed_experts": True, + # "repetition_penalty": 1.0, }, }, timeout=60, @@ -60,17 +60,10 @@ def test_routing_export(url: str = "http://localhost:8000"): print(f"{'=' * 50}") print(f"Shape: {shape}") print(f"Dtype: {dtype}") - print(f"Num MoE layers: {shape[0]}") - print(f"Num tokens: {shape[1]}") + print(f"Num tokens: {shape[0]}") + print(f"Num MoE layers: {shape[1]}") print(f"Top-K: {shape[2]}") - # Verify dtype is int8 (for models with ≤127 experts) or int16 - if dtype_str not in ("int8", "int16"): - print(f"\nERROR: Expected dtype int8 or int16, got {dtype_str}") - print("This suggests dtype optimization is not working correctly.") - return False - print(f"\nDtype check PASSED: {dtype_str} (compact representation)") - # Compute payload size savings int32_size = np.prod(shape) * 4 actual_size = len(data) @@ -78,9 +71,9 @@ def test_routing_export(url: str = "http://localhost:8000"): print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)") print(f"\nSample routing (first layer, first 5 tokens):") - num_tokens_to_show = shape[1] + num_tokens_to_show = shape[0] for i in range(num_tokens_to_show): - print(f" Token {i}: experts {routing_array[0, i, :].tolist()}") + print(f" Token {i}: experts {routing_array[i, 0, :].tolist()}") if np.all(routing_array == 0): print("\nWARNING: All routing data is zeros. Capture may not be working correctly.") From 0a41a2a22ebe084f244fc288ef7c6cdde6792ecb Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 11 Feb 2026 09:15:04 +0000 Subject: [PATCH 71/71] fix bmm weight slice --- .../layer_weights/meta_weights/mm_weight/mm_weight.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 3630bc2c0..da9b3f432 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -193,7 +193,9 @@ def _create_weight(self): def load_hf_weights(self, weights: Dict[str, torch.Tensor]): for weight_name in self.weight_names: if weight_name in weights: - weight = self.param_slicer._slice_weight(weights[weight_name]) + tp_start = self.tp_rank_ * self.dim0 + tp_end = (self.tp_rank_ + 1) * self.dim0 + weight = weights[weight_name][tp_start:tp_end, :, :] self.weight.copy_(weight) self.weight.load_ok = True return
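
The _get/_cfg helpers introduced in PATCH 70 behave differently from a plain kwargs.get(key, default) when a caller passes an explicit None (for example a JSON null in the request "parameters"): .get returns that None instead of the default, and the ctypes-backed SamplingParams fields would generally reject a None assignment. Below is a minimal standalone sketch of the pattern; the helper name _get_with_default and the sample parameter dict are illustrative only, not part of the patch.

def _get_with_default(kwargs, key, default):
    # Fall back to the default when the key is missing OR explicitly None,
    # mirroring the _get/_cfg helpers added in sampling_params.py.
    v = kwargs.get(key)
    return v if v is not None else default

params = {"top_k": None, "temperature": 0.7}
assert params.get("top_k", -1) is None                  # plain .get keeps the explicit None
assert _get_with_default(params, "top_k", -1) == -1     # helper restores the default
assert _get_with_default(params, "temperature", 1.0) == 0.7
assert _get_with_default(params, "top_p", 1.0) == 1.0   # missing key also falls back

The same None-coalescing is applied to the GenerationConfig values in load_generation_cfg, where a model's generation_config.json can carry explicit nulls for fields such as top_k or temperature.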