Merged
26 changes: 11 additions & 15 deletions lightllm/common/req_manager.py
@@ -7,6 +7,7 @@
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.config_utils import get_vocab_size
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

logger = init_logger(__name__)

@@ -155,7 +156,11 @@ def init_req_sampling_params(self, req):
else:
self.req_to_out_token_id_counter[req.req_idx].fill_(0)
if req.sampling_param.shm_param.input_penalty and req.need_out_token_id_statistics:
prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids_numpy()).pin_memory().cuda(non_blocking=True)
prompt_ids = g_pin_mem_manager.gen_from_list(
key="prompt_ids_for_penalty",
data=req.shm_req.get_prompt_ids_numpy(),
dtype=torch.int32,
).cuda(non_blocking=True)
token_id_counter(
prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx]
)
@@ -214,22 +219,13 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
cum_sum_len += len(id_to_count)
p_cumsum_seq_len.append(cum_sum_len)

from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

p_token_ids_tensor = g_pin_mem_manager.alloc_pin_tensor(
key="p_token_ids", size=len(p_token_ids), dtype=torch.int32
)
p_token_ids_tensor.numpy()[:] = p_token_ids

p_token_counts_tensor = g_pin_mem_manager.alloc_pin_tensor(
key="p_token_counts", size=len(p_token_counts), dtype=torch.int32
p_token_ids_tensor = g_pin_mem_manager.gen_from_list(key="p_token_ids", data=p_token_ids, dtype=torch.int32)
p_token_counts_tensor = g_pin_mem_manager.gen_from_list(
key="p_token_counts", data=p_token_counts, dtype=torch.int32
)
p_token_counts_tensor.numpy()[:] = p_token_counts

p_cumsum_seq_len_tensor = g_pin_mem_manager.alloc_pin_tensor(
key="p_cumsum_seq_len", size=len(p_cumsum_seq_len), dtype=torch.int32
p_cumsum_seq_len_tensor = g_pin_mem_manager.gen_from_list(
key="p_cumsum_seq_len", data=p_cumsum_seq_len, dtype=torch.int32
)
p_cumsum_seq_len_tensor.numpy()[:] = p_cumsum_seq_len

return (
p_token_ids_tensor.cuda(non_blocking=True),
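Note on the new call sites: `g_pin_mem_manager.gen_from_list(key=..., data=..., dtype=...)` replaces the two-step `alloc_pin_tensor(...)` + `tensor.numpy()[:] = data` pattern. The sketch below shows what such a helper plausibly does; the signature is inferred from the call sites in this diff, and the real implementation in `pin_mem_manager.py` may differ.

```python
import torch
from typing import Dict, Sequence


class PinMemManagerSketch:
    """Illustrative only: cache one pinned (page-locked) CPU tensor per key,
    grow it on demand, and copy list/ndarray data into it so the later
    .cuda(non_blocking=True) can be a true asynchronous host-to-device copy."""

    def __init__(self):
        self._buffers: Dict[str, torch.Tensor] = {}

    def alloc_pin_tensor(self, key: str, size: int, dtype: torch.dtype) -> torch.Tensor:
        buf = self._buffers.get(key)
        if buf is None or buf.numel() < size or buf.dtype != dtype:
            buf = torch.empty(size, dtype=dtype, pin_memory=True)
            self._buffers[key] = buf
        return buf[:size]

    def gen_from_list(self, key: str, data: Sequence, dtype: torch.dtype) -> torch.Tensor:
        # One-call replacement for alloc_pin_tensor(...) + buf.numpy()[:] = data.
        buf = self.alloc_pin_tensor(key, len(data), dtype)
        buf.copy_(torch.as_tensor(data, dtype=dtype))
        return buf
```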
3 changes: 2 additions & 1 deletion lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -241,7 +241,8 @@ def match_prefix(self, key, update_refs=False):
value = torch.zeros((0,), device="cpu", dtype=self._value_dtype)
return tree_node, len(value), value
else:
self.dec_node_ref_counter(self.root_node)
if update_refs:
self.dec_node_ref_counter(self.root_node)
return None, 0, None

def _match_prefix_helper(
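The one-line guard above matters because (presumably, based on this diff) match_prefix only takes references on tree nodes when update_refs=True, so the miss path must only release them under the same flag; otherwise the root's counter is decremented without a matching increment. A simplified illustration of that invariant, not the real RadixCache API:

```python
class _Node:
    def __init__(self):
        self.ref_count = 0


def match_prefix_sketch(root: _Node, found: bool, update_refs: bool):
    if update_refs:
        root.ref_count += 1      # reference taken only when the caller asked for it
    if found:
        return root
    if update_refs:              # the fix: release only what was actually taken
        root.ref_count -= 1
    return None
```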
@@ -3,6 +3,7 @@
from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache
from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
from lightllm.utils.envs_utils import get_env_start_args


@@ -16,7 +17,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
b_mask_eos_reqs,
is_all_greedy,
) = _get_post_sample_tensors(reqs)
eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True)
eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True)

sampling_params_manager = g_infer_context.req_manager.req_sampling_params_manager

@@ -128,12 +129,14 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
is_all_greedy = False
req_idxes.append(req_obj.req_idx)

req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True)
temperatures_cpu = torch.tensor(temperatures, dtype=torch.float, device="cpu", pin_memory=True)
top_ps_cpu = torch.tensor(top_ps, dtype=torch.float, device="cpu", pin_memory=True)
top_ks_cpu = torch.tensor(top_ks, dtype=torch.int32, device="cpu", pin_memory=True)
length_penalty_param_cpu = torch.tensor(length_penalty_param, dtype=torch.int32, device="cpu", pin_memory=True)
mask_eos_reqs_cpu = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cpu", pin_memory=True)
req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32)
temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32)
top_ps_cpu = g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32)
top_ks_cpu = g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32)
length_penalty_param_cpu = g_pin_mem_manager.gen_from_list(
key="length_penalty_param", data=length_penalty_param, dtype=torch.int32
)
mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool)

return (
req_idxes_cpu.cuda(non_blocking=True),
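Background on why these small sampling-parameter tensors are staged in pinned host memory before `.cuda(non_blocking=True)`: the copy can only proceed asynchronously (and at full bandwidth) from page-locked memory; from pageable memory it degrades to an effectively synchronous transfer. A generic timing sketch, not part of the PR, that can be used to observe the difference:

```python
import torch


def h2d_time_ms(pinned: bool, n: int = 1 << 20, iters: int = 100) -> float:
    """Average host-to-device copy time from pageable vs. pinned CPU memory."""
    assert torch.cuda.is_available()
    src = torch.arange(n, dtype=torch.int32)
    if pinned:
        src = src.pin_memory()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        _ = src.cuda(non_blocking=True)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


# Typically prints a noticeably smaller number for the pinned variant.
# print(h2d_time_ms(pinned=False), h2d_time_ms(pinned=True))
```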
136 changes: 99 additions & 37 deletions test/benchmark/service/benchmark_qps.py
@@ -103,6 +103,10 @@ def get_custom_input_data(data_path, output_len, tokenizer, range_ratio):
model_name = []


# Minimal fix: one retry on transient network errors.
_DEFAULT_RETRY = 1


async def async_post_stream_openai(url, prompt, max_new_tokens, session):
try:
text_input, input_len = prompt
@@ -116,21 +120,34 @@ async def async_post_stream_openai(url, prompt, max_new_tokens, session):
"best_of": 1,
}
headers = {"Content-Type": "application/json"}
used_time = []
start_time = time.time()
last_time = start_time
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
return []

async for line in response.content:
line = line.strip()
if line:
current_time = time.time()
elapsed_time = current_time - last_time
used_time.append(elapsed_time)
last_time = current_time
return used_time, input_len

for attempt in range(_DEFAULT_RETRY + 1):
used_time = []
start_time = time.time()
last_time = start_time
try:
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
return []
Comment on lines +130 to +131 (Contributor review, severity: medium)

The retry loop is intended to handle transient errors, but returning [] here on any non-200 status code prevents retries for HTTP errors (e.g., 503 Service Unavailable). Consider checking the status code and only returning immediately for non-retriable errors (like 4xx client errors), while continuing the loop for retriable server errors (e.g., 5xx).
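A minimal sketch of what this suggestion could look like, assuming the benchmark's aiohttp session and the `_DEFAULT_RETRY` constant added in this PR; the 4xx/5xx split and the backoff are illustrative, not code from the PR:

```python
import asyncio
import aiohttp

_DEFAULT_RETRY = 1  # same spirit as the PR: one retry on transient failures


async def post_with_status_aware_retry(url: str, data: dict) -> list:
    """Retry only retriable (5xx) responses; give up immediately on 4xx."""
    headers = {"Content-Type": "application/json"}
    async with aiohttp.ClientSession() as session:
        for attempt in range(_DEFAULT_RETRY + 1):
            async with session.post(url, headers=headers, json=data) as response:
                if response.status == 200:
                    # Stream and time the chunks here, as the benchmark functions do.
                    return [line async for line in response.content]
                if 400 <= response.status < 500:
                    return []  # client error: retrying will not help
                if attempt >= _DEFAULT_RETRY:
                    return []  # retriable 5xx, but retries exhausted
            await asyncio.sleep(0.5 * (attempt + 1))  # simple backoff before retrying
    return []
```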


try:
async for line in response.content:
line = line.strip()
if line:
current_time = time.time()
elapsed_time = current_time - last_time
used_time.append(elapsed_time)
last_time = current_time
except Exception:
# server may disconnect mid-stream; keep partial timings if any.
pass

if used_time or attempt >= _DEFAULT_RETRY:
return used_time, input_len
except Exception as e:
if attempt >= _DEFAULT_RETRY:
print(e)
return []
except Exception as e:
print(e)
pass
@@ -149,21 +166,33 @@ async def async_post_stream_lightllm(url, prompt, max_new_tokens, session):
},
}
headers = {"Content-Type": "application/json"}
used_time = []
start_time = time.time()
last_time = start_time
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
return []

async for line in response.content:
if line and line.startswith(b"data:"):
# print(line)
current_time = time.time()
elapsed_time = current_time - last_time
used_time.append(elapsed_time)
last_time = current_time
return used_time, input_len

for attempt in range(_DEFAULT_RETRY + 1):
used_time = []
start_time = time.time()
last_time = start_time
try:
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
return []

try:
async for line in response.content:
if line and line.startswith(b"data:"):
current_time = time.time()
elapsed_time = current_time - last_time
used_time.append(elapsed_time)
last_time = current_time
except Exception:
# server may disconnect mid-stream; keep partial timings if any.
pass

if used_time or attempt >= _DEFAULT_RETRY:
return used_time, input_len
except Exception as e:
if attempt >= _DEFAULT_RETRY:
print(e)
return []
except Exception as e:
print(e)
pass
@@ -187,6 +216,7 @@ async def continuous_sender(
while not stop_send.is_set():
if not continuous_send and sent_count[0] >= max_count:
break

prompt = prompts[prompt_index % len(prompts)]
max_tokens = max_new_tokens[prompt_index % len(max_new_tokens)]

@@ -212,18 +242,42 @@ async def response_collector(
force_terminate,
pending_tasks,
):
# Maximum wait time per request on the collector side, so a network failure cannot hang the collector forever
task_timeout_s = 600
try:
while True:
try:
task = await asyncio.wait_for(request_queue.get(), timeout=1.0)
result, input_len = await task
request_queue.task_done()
assert result is not None
if len(result) >= 1 and not stop_send.is_set():
results.append((result, input_len))
result = None
input_len = 0
try:
try:
result_tuple = await asyncio.wait_for(task, timeout=task_timeout_s)
except asyncio.TimeoutError:
print("\nError collecting response: task timeout")
if not task.done():
task.cancel()
result_tuple = None

if isinstance(result_tuple, tuple) and len(result_tuple) == 2:
result, input_len = result_tuple
else:
result = None
input_len = 0
except Exception as e:
print(f"\nError collecting response: {e}")
finally:
# Make sure the queue is never left permanently backlogged because of a continue/exception
request_queue.task_done()

# Advance the counter whether the request succeeded or failed, so waiting for the remaining responses cannot deadlock
current_count = counter[0] + 1
counter[0] = current_count
print(f"\rfinished_reqs:{current_count} / target_reqs:{reqs_num} / sent_reqs:{sent_count[0]}", end="")

if result is not None:
if len(result) >= 1 and not stop_send.is_set():
results.append((result, input_len))
if len(results) >= reqs_num and not stop_send.is_set():
end_time[0] = time.time()
print("\nReached target number of responses")
@@ -245,6 +299,7 @@
continue
except Exception as e:
print(f"\nError collecting response: {e}")
continue
finally:
if force_terminate:
for task in pending_tasks:
@@ -253,7 +308,15 @@


async def run_continuous_benchmark(
async_task, url, prompts, max_new_tokens, reqs_num, num_clients, input_qps, force_terminate, continuous_send
async_task,
url,
prompts,
max_new_tokens,
reqs_num,
num_clients,
input_qps,
force_terminate,
continuous_send,
):
request_queue = asyncio.Queue()
stop_event = asyncio.Event()
@@ -414,7 +477,6 @@ def main():
)
)
loop.close()
print(len(results))
first_token_time = []
decode_token_time = []
request_time = []
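One more note on the response_collector changes above: moving request_queue.task_done() into a finally block means the queue's unfinished-task count is decremented even when a task raises or times out, so any queue.join()-style wait elsewhere in the benchmark (if used; not shown in this diff) cannot hang. A self-contained illustration of that pattern:

```python
import asyncio


async def worker(queue: asyncio.Queue) -> None:
    while True:
        item = await queue.get()
        try:
            if item % 3 == 0:
                raise RuntimeError("simulated failure")
            await asyncio.sleep(0.01)  # pretend to process the item
        except Exception:
            pass  # swallow the error; the finally block below still runs
        finally:
            # Without this, queue.join() would wait forever once any item failed.
            queue.task_done()


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(10):
        queue.put_nowait(i)
    consumer = asyncio.create_task(worker(queue))
    await queue.join()  # returns only after task_done() has run for every item
    consumer.cancel()


asyncio.run(main())
```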