diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index 40c8aa993..33bdca447 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -7,6 +7,7 @@
 from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
 from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
 from lightllm.utils.config_utils import get_vocab_size
+from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
 
 logger = init_logger(__name__)
 
@@ -155,7 +156,11 @@ def init_req_sampling_params(self, req):
         else:
             self.req_to_out_token_id_counter[req.req_idx].fill_(0)
         if req.sampling_param.shm_param.input_penalty and req.need_out_token_id_statistics:
-            prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids_numpy()).pin_memory().cuda(non_blocking=True)
+            prompt_ids = g_pin_mem_manager.gen_from_list(
+                key="prompt_ids_for_penalty",
+                data=req.shm_req.get_prompt_ids_numpy(),
+                dtype=torch.int32,
+            ).cuda(non_blocking=True)
             token_id_counter(
                 prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx]
             )
@@ -214,22 +219,13 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
             cum_sum_len += len(id_to_count)
             p_cumsum_seq_len.append(cum_sum_len)
 
-        from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
-
-        p_token_ids_tensor = g_pin_mem_manager.alloc_pin_tensor(
-            key="p_token_ids", size=len(p_token_ids), dtype=torch.int32
-        )
-        p_token_ids_tensor.numpy()[:] = p_token_ids
-
-        p_token_counts_tensor = g_pin_mem_manager.alloc_pin_tensor(
-            key="p_token_counts", size=len(p_token_counts), dtype=torch.int32
+        p_token_ids_tensor = g_pin_mem_manager.gen_from_list(key="p_token_ids", data=p_token_ids, dtype=torch.int32)
+        p_token_counts_tensor = g_pin_mem_manager.gen_from_list(
+            key="p_token_counts", data=p_token_counts, dtype=torch.int32
         )
-        p_token_counts_tensor.numpy()[:] = p_token_counts
-
-        p_cumsum_seq_len_tensor = g_pin_mem_manager.alloc_pin_tensor(
-            key="p_cumsum_seq_len", size=len(p_cumsum_seq_len), dtype=torch.int32
+        p_cumsum_seq_len_tensor = g_pin_mem_manager.gen_from_list(
+            key="p_cumsum_seq_len", data=p_cumsum_seq_len, dtype=torch.int32
        )
-        p_cumsum_seq_len_tensor.numpy()[:] = p_cumsum_seq_len
 
         return (
             p_token_ids_tensor.cuda(non_blocking=True),
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index c51774898..88b099459 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -241,7 +241,8 @@ def match_prefix(self, key, update_refs=False):
             value = torch.zeros((0,), device="cpu", dtype=self._value_dtype)
             return tree_node, len(value), value
         else:
-            self.dec_node_ref_counter(self.root_node)
+            if update_refs:
+                self.dec_node_ref_counter(self.root_node)
             return None, 0, None
 
     def _match_prefix_helper(
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
index e2ccf290e..ca3901ebd 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
@@ -3,6 +3,7 @@
 from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
 from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache
 from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
+from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
 from lightllm.utils.envs_utils import get_env_start_args
 
 
@@ -16,7 +17,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
         b_mask_eos_reqs,
         is_all_greedy,
     ) = _get_post_sample_tensors(reqs)
-    eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True)
+    eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True)
 
     sampling_params_manager = g_infer_context.req_manager.req_sampling_params_manager
 
@@ -128,12 +129,14 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
             is_all_greedy = False
         req_idxes.append(req_obj.req_idx)
 
-    req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True)
-    temperatures_cpu = torch.tensor(temperatures, dtype=torch.float, device="cpu", pin_memory=True)
-    top_ps_cpu = torch.tensor(top_ps, dtype=torch.float, device="cpu", pin_memory=True)
-    top_ks_cpu = torch.tensor(top_ks, dtype=torch.int32, device="cpu", pin_memory=True)
-    length_penalty_param_cpu = torch.tensor(length_penalty_param, dtype=torch.int32, device="cpu", pin_memory=True)
-    mask_eos_reqs_cpu = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cpu", pin_memory=True)
+    req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32)
+    temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32)
+    top_ps_cpu = g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32)
+    top_ks_cpu = g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32)
+    length_penalty_param_cpu = g_pin_mem_manager.gen_from_list(
+        key="length_penalty_param", data=length_penalty_param, dtype=torch.int32
+    )
+    mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool)
 
     return (
         req_idxes_cpu.cuda(non_blocking=True),
diff --git a/test/benchmark/service/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py
index abf312ee2..8249ae2c4 100644
--- a/test/benchmark/service/benchmark_qps.py
+++ b/test/benchmark/service/benchmark_qps.py
@@ -103,6 +103,10 @@ def get_custom_input_data(data_path, output_len, tokenizer, range_ratio):
 model_name = []
 
+# Minimal fix: one retry on transient network errors.
+_DEFAULT_RETRY = 1
+
+
 async def async_post_stream_openai(url, prompt, max_new_tokens, session):
     try:
         text_input, input_len = prompt
@@ -116,21 +120,34 @@ async def async_post_stream_openai(url, prompt, max_new_tokens, session):
             "best_of": 1,
         }
         headers = {"Content-Type": "application/json"}
-        used_time = []
-        start_time = time.time()
-        last_time = start_time
-        async with session.post(url, headers=headers, json=data) as response:
-            if response.status != 200:
-                return []
-
-            async for line in response.content:
-                line = line.strip()
-                if line:
-                    current_time = time.time()
-                    elapsed_time = current_time - last_time
-                    used_time.append(elapsed_time)
-                    last_time = current_time
-            return used_time, input_len
+
+        for attempt in range(_DEFAULT_RETRY + 1):
+            used_time = []
+            start_time = time.time()
+            last_time = start_time
+            try:
+                async with session.post(url, headers=headers, json=data) as response:
+                    if response.status != 200:
+                        return []
+
+                    try:
+                        async for line in response.content:
+                            line = line.strip()
+                            if line:
+                                current_time = time.time()
+                                elapsed_time = current_time - last_time
+                                used_time.append(elapsed_time)
+                                last_time = current_time
+                    except Exception:
+                        # server may disconnect mid-stream; keep partial timings if any.
+                        pass
+
+                    if used_time or attempt >= _DEFAULT_RETRY:
+                        return used_time, input_len
+            except Exception as e:
+                if attempt >= _DEFAULT_RETRY:
+                    print(e)
+                    return []
     except Exception as e:
         print(e)
         pass
@@ -149,21 +166,33 @@ async def async_post_stream_lightllm(url, prompt, max_new_tokens, session):
             },
         }
         headers = {"Content-Type": "application/json"}
-        used_time = []
-        start_time = time.time()
-        last_time = start_time
-        async with session.post(url, headers=headers, json=data) as response:
-            if response.status != 200:
-                return []
-
-            async for line in response.content:
-                if line and line.startswith(b"data:"):
-                    # print(line)
-                    current_time = time.time()
-                    elapsed_time = current_time - last_time
-                    used_time.append(elapsed_time)
-                    last_time = current_time
-            return used_time, input_len
+
+        for attempt in range(_DEFAULT_RETRY + 1):
+            used_time = []
+            start_time = time.time()
+            last_time = start_time
+            try:
+                async with session.post(url, headers=headers, json=data) as response:
+                    if response.status != 200:
+                        return []
+
+                    try:
+                        async for line in response.content:
+                            if line and line.startswith(b"data:"):
+                                current_time = time.time()
+                                elapsed_time = current_time - last_time
+                                used_time.append(elapsed_time)
+                                last_time = current_time
+                    except Exception:
+                        # server may disconnect mid-stream; keep partial timings if any.
+                        pass
+
+                    if used_time or attempt >= _DEFAULT_RETRY:
+                        return used_time, input_len
+            except Exception as e:
+                if attempt >= _DEFAULT_RETRY:
+                    print(e)
+                    return []
     except Exception as e:
         print(e)
         pass
@@ -187,6 +216,7 @@ async def continuous_sender(
     while not stop_send.is_set():
         if not continuous_send and sent_count[0] >= max_count:
             break
+
         prompt = prompts[prompt_index % len(prompts)]
         max_tokens = max_new_tokens[prompt_index % len(max_new_tokens)]
 
@@ -212,18 +242,42 @@ async def response_collector(
     force_terminate,
     pending_tasks,
 ):
+    # maximum time the collector waits on a single request, so a network failure cannot hang it forever
+    task_timeout_s = 600
     try:
         while True:
             try:
                 task = await asyncio.wait_for(request_queue.get(), timeout=1.0)
-                result, input_len = await task
-                request_queue.task_done()
-                assert result is not None
-                if len(result) >= 1 and not stop_send.is_set():
-                    results.append((result, input_len))
+                result = None
+                input_len = 0
+                try:
+                    try:
+                        result_tuple = await asyncio.wait_for(task, timeout=task_timeout_s)
+                    except asyncio.TimeoutError:
+                        print("\nError collecting response: task timeout")
+                        if not task.done():
+                            task.cancel()
+                        result_tuple = None
+
+                    if isinstance(result_tuple, tuple) and len(result_tuple) == 2:
+                        result, input_len = result_tuple
+                    else:
+                        result = None
+                        input_len = 0
+                except Exception as e:
+                    print(f"\nError collecting response: {e}")
+                finally:
+                    # make sure the queue is never left permanently backlogged by a continue/exception
+                    request_queue.task_done()
+
+                # advance the counter whether the request succeeded or failed, so waiting for the remaining responses cannot deadlock
                 current_count = counter[0] + 1
                 counter[0] = current_count
                 print(f"\rfinished_reqs:{current_count} / target_reqs:{reqs_num} / sent_reqs:{sent_count[0]}", end="")
+
+                if result is not None:
+                    if len(result) >= 1 and not stop_send.is_set():
+                        results.append((result, input_len))
                 if len(results) >= reqs_num and not stop_send.is_set():
                     end_time[0] = time.time()
                     print("\nReached target number of responses")
@@ -245,6 +299,7 @@ async def response_collector(
                 continue
             except Exception as e:
                 print(f"\nError collecting response: {e}")
+                continue
    finally:
        if force_terminate:
            for task in pending_tasks:
@@ -253,7 +308,15 @@
 
 
 async def run_continuous_benchmark(
-    async_task, url, prompts, max_new_tokens, reqs_num, num_clients, input_qps, force_terminate, continuous_send
+    async_task,
+    url,
+    prompts,
+    max_new_tokens,
+    reqs_num,
+    num_clients,
+    input_qps,
+    force_terminate,
+    continuous_send,
 ):
     request_queue = asyncio.Queue()
     stop_event = asyncio.Event()
@@ -414,7 +477,6 @@ def main():
         )
     )
     loop.close()
-    print(len(results))
     first_token_time = []
     decode_token_time = []
     request_time = []
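
Note on the pinned-memory pattern used throughout this diff: the call sites now stage small host-side lists through g_pin_mem_manager.gen_from_list(...) and then copy with .cuda(non_blocking=True), instead of allocating a fresh pinned tensor on every call. The sketch below illustrates one way such a helper could be built around cached, per-key pinned buffers; it is an assumption for illustration only -- the class name _PinMemSketch, the doubling growth policy, and the exact signature are not taken from lightllm's actual PinMemManager.

# Sketch only: a cached pinned-buffer helper in the spirit of gen_from_list.
# Names and the growth policy are assumptions, not the real PinMemManager API.
import torch


class _PinMemSketch:
    def __init__(self):
        # key -> reusable 1-D pinned CPU tensor
        self._buffers = {}

    def _get_buffer(self, key, size, dtype):
        buf = self._buffers.get(key)
        if buf is None or buf.numel() < size or buf.dtype != dtype:
            # Pinning host memory is the expensive step, so allocate once per key
            # and grow geometrically instead of re-pinning on every call.
            new_size = max(size, 2 * buf.numel()) if buf is not None else size
            buf = torch.empty((new_size,), dtype=dtype, device="cpu", pin_memory=True)
            self._buffers[key] = buf
        return buf

    def gen_from_list(self, key, data, dtype):
        # Accepts a Python list or numpy array; returns a pinned CPU view sized to
        # the data, ready for a non-blocking host-to-device copy.
        src = torch.as_tensor(data, dtype=dtype)
        view = self._get_buffer(key, src.numel(), dtype)[: src.numel()]
        view.copy_(src)
        return view


# Usage mirroring the call sites in the diff:
#   top_ks_cpu = manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32)
#   top_ks_gpu = top_ks_cpu.cuda(non_blocking=True)

One caveat with reusing a keyed buffer this way: the non-blocking host-to-device copy must complete before the same key is written again, which holds as long as each key is filled at most once per step and the stream is synchronized before the next write.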