
Commit f1599ca

feat(metrics): Add prefill KV compute metric excluding cached tokens (#30189)
Signed-off-by: Ziliang Peng <ziliang@character.ai>
1 parent 60d1725 commit f1599ca
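
Per finished request, the new vllm:request_prefill_kv_computed_tokens histogram records how many prompt tokens actually had their KV computed during prefill: the prompt length minus any prefix-cache hit, clamped so a negative cached-token count counts as zero. A minimal standalone sketch of that calculation (the helper name is illustrative, not part of this change; the example values come from the tests below):

def prefill_kv_computed(num_prompt_tokens: int, num_cached_tokens: int) -> int:
    # Clamp cached tokens at 0 so a bogus negative count cannot inflate
    # the result beyond the prompt length.
    return num_prompt_tokens - max(num_cached_tokens, 0)

assert prefill_kv_computed(10000, 1200) == 8800  # prefix-cache hit
assert prefill_kv_computed(2000, 0) == 2000      # no prefix cache
assert prefill_kv_computed(100, -1) == 100       # negative treated as 0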

File tree: 4 files changed, +126 -1 lines changed

tests/v1/metrics/test_stats.py

Lines changed: 102 additions & 1 deletion
@@ -1,8 +1,109 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.engine import FinishReason
+from vllm.v1.metrics.stats import IterationStats, RequestStateStats
 
 
 def test_iteration_stats_repr():
     iteration_stats = IterationStats()
     assert repr(iteration_stats).startswith("IterationStats(")
+
+
+def test_prefill_kv_computed_with_cache():
+    """Test that prefill KV compute correctly excludes cached tokens."""
+    iteration_stats = IterationStats()
+    req_stats = RequestStateStats(arrival_time=0.0)
+    req_stats.scheduled_ts = 0.1
+    req_stats.first_token_ts = 0.5
+    req_stats.last_token_ts = 5.0
+    req_stats.num_generation_tokens = 50
+
+    # Case 1: With prefix cache (1200 tokens cached)
+    iteration_stats.update_from_finished_request(
+        finish_reason=FinishReason.STOP,
+        num_prompt_tokens=10000,
+        max_tokens_param=100,
+        req_stats=req_stats,
+        num_cached_tokens=1200,
+    )
+
+    finished_req = iteration_stats.finished_requests[0]
+    assert finished_req.num_prompt_tokens == 10000
+    assert finished_req.num_cached_tokens == 1200
+
+    # Verify calculation: prefill KV = prompt tokens - cached tokens
+    prefill_kv_computed = finished_req.num_prompt_tokens - max(
+        finished_req.num_cached_tokens, 0
+    )
+    assert prefill_kv_computed == 8800  # 10000 - 1200
+
+
+def test_prefill_kv_computed_no_cache():
+    """Test prefill KV compute without prefix caching."""
+    iteration_stats = IterationStats()
+    req_stats = RequestStateStats(arrival_time=0.0)
+    req_stats.scheduled_ts = 0.1
+    req_stats.first_token_ts = 0.5
+    req_stats.last_token_ts = 2.0
+    req_stats.num_generation_tokens = 10
+
+    # Case 2: No prefix cache
+    iteration_stats.update_from_finished_request(
+        finish_reason=FinishReason.STOP,
+        num_prompt_tokens=2000,
+        max_tokens_param=100,
+        req_stats=req_stats,
+        num_cached_tokens=0,
+    )
+
+    finished_req = iteration_stats.finished_requests[0]
+    assert finished_req.num_prompt_tokens == 2000
+    assert finished_req.num_cached_tokens == 0
+
+    # Verify calculation: prefill KV = full prompt when no cache
+    prefill_kv_computed = finished_req.num_prompt_tokens - max(
+        finished_req.num_cached_tokens, 0
+    )
+    assert prefill_kv_computed == 2000
+
+
+def test_prefill_kv_computed_edge_cases():
+    """Test edge cases for prefill KV compute calculation."""
+    iteration_stats = IterationStats()
+    req_stats = RequestStateStats(arrival_time=0.0)
+    req_stats.scheduled_ts = 0.1
+    req_stats.first_token_ts = 0.5
+    req_stats.last_token_ts = 1.0
+    req_stats.num_generation_tokens = 1
+
+    # Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully)
+    iteration_stats.update_from_finished_request(
+        finish_reason=FinishReason.STOP,
+        num_prompt_tokens=100,
+        max_tokens_param=10,
+        req_stats=req_stats,
+        num_cached_tokens=-1,
+    )
+
+    finished_req = iteration_stats.finished_requests[0]
+    # max() should handle negative values
+    prefill_kv_computed = finished_req.num_prompt_tokens - max(
+        finished_req.num_cached_tokens, 0
+    )
+    assert prefill_kv_computed == 100  # Should treat negative as 0
+
+    # Case 4: All tokens cached (shouldn't happen in practice)
+    iteration_stats2 = IterationStats()
+    iteration_stats2.update_from_finished_request(
+        finish_reason=FinishReason.STOP,
+        num_prompt_tokens=100,
+        max_tokens_param=10,
+        req_stats=req_stats,
+        num_cached_tokens=100,
+    )
+
+    finished_req2 = iteration_stats2.finished_requests[0]
+    prefill_kv_computed2 = finished_req2.num_prompt_tokens - max(
+        finished_req2.num_cached_tokens, 0
+    )
+    assert prefill_kv_computed2 == 0  # All cached, nothing computed

vllm/v1/engine/output_processor.py

Lines changed: 1 addition & 0 deletions
@@ -650,6 +650,7 @@ def _update_stats_from_finished(
             ),
             max_tokens_param=req_state.max_tokens_param,
             req_stats=req_state.stats,
+            num_cached_tokens=req_state.num_cached_tokens,
         )
         self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

vllm/v1/metrics/loggers.py

Lines changed: 20 additions & 0 deletions
@@ -870,6 +870,19 @@ def __init__(
             histogram_decode_time_request, engine_indexes, model_name
         )
 
+        histogram_prefill_kv_computed_request = self._histogram_cls(
+            name="vllm:request_prefill_kv_computed_tokens",
+            documentation=(
+                "Histogram of new KV tokens computed during prefill "
+                "(excluding cached tokens)."
+            ),
+            buckets=build_1_2_5_buckets(max_model_len),
+            labelnames=labelnames,
+        )
+        self.histogram_prefill_kv_computed_request = make_per_engine(
+            histogram_prefill_kv_computed_request, engine_indexes, model_name
+        )
+
         #
         # KV Cache residency metrics
         #
@@ -1118,6 +1131,13 @@ def record(
                 self.histogram_decode_time_request[engine_idx].observe(
                     finished_request.decode_time
                 )
+                # Calculate prefill KV compute (excludes cached tokens)
+                prefill_kv_computed = finished_request.num_prompt_tokens - max(
+                    finished_request.num_cached_tokens, 0
+                )
+                self.histogram_prefill_kv_computed_request[engine_idx].observe(
+                    prefill_kv_computed
+                )
                 self.histogram_num_prompt_tokens_request[engine_idx].observe(
                     finished_request.num_prompt_tokens
                 )
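
The histogram's buckets follow the same 1-2-5 progression used by the other token-count histograms, capped by max_model_len. A rough sketch of how buckets of that shape are generated; this illustrates the scheme only and is not vLLM's actual build_1_2_5_buckets:

def build_1_2_5_buckets_sketch(max_value: int) -> list[int]:
    # Produce 1, 2, 5, 10, 20, 50, ... stopping once a value exceeds max_value.
    buckets: list[int] = []
    exponent = 0
    while True:
        for mantissa in (1, 2, 5):
            value = mantissa * 10**exponent
            if value > max_value:
                return buckets
            buckets.append(value)
        exponent += 1

print(build_1_2_5_buckets_sketch(100))  # [1, 2, 5, 10, 20, 50, 100]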

vllm/v1/metrics/stats.py

Lines changed: 3 additions & 0 deletions
@@ -224,6 +224,7 @@ class FinishedRequestStats:
     decode_time: float = 0.0
     mean_time_per_output_token: float = 0.0
     is_corrupted: bool = False
+    num_cached_tokens: int = 0
 
 
 class IterationStats:
@@ -330,6 +331,7 @@ def update_from_finished_request(
         num_prompt_tokens: int,
         max_tokens_param: int | None,
         req_stats: RequestStateStats,
+        num_cached_tokens: int = 0,
     ):
         e2e_latency = self._time_since(req_stats.arrival_time)
 
@@ -367,6 +369,7 @@ def update_from_finished_request(
             decode_time=decode_time,
             mean_time_per_output_token=mean_time_per_output_token,
             is_corrupted=req_stats.is_corrupted,
+            num_cached_tokens=num_cached_tokens,
         )
         self.finished_requests.append(finished_req)

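Because num_cached_tokens defaults to 0, existing callers of update_from_finished_request that do not pass the new keyword keep working unchanged and simply report zero cached tokens; only the output_processor path above forwards the real value. A short sketch mirroring the new tests (timestamps and token counts are arbitrary):

from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, RequestStateStats

stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 1.0
req_stats.num_generation_tokens = 8

# Old-style call without the new keyword: cached tokens default to 0.
stats.update_from_finished_request(
    finish_reason=FinishReason.STOP,
    num_prompt_tokens=500,
    max_tokens_param=16,
    req_stats=req_stats,
)
assert stats.finished_requests[0].num_cached_tokens == 0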