Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 09:25:34 +08:00)
feat(metrics): Add prefill KV compute metric excluding cached tokens (#30189)
Signed-off-by: Ziliang Peng <ziliang@character.ai>
This commit is contained in: parent 60d17251c9, commit f1599ca55d

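The new metric is derived from two values already tracked per finished request: the prompt length and the number of prefix-cache hits. The sketch below restates the calculation the diff applies; the helper function itself is illustrative and not part of the commit, but the names and expected values mirror the tests added below.

def prefill_kv_computed_tokens(num_prompt_tokens: int, num_cached_tokens: int) -> int:
    """New KV tokens actually computed during prefill.

    Cached tokens are excluded, and a negative cache count (which should not
    happen in practice) is clamped to zero so the result never exceeds the
    prompt length.
    """
    return num_prompt_tokens - max(num_cached_tokens, 0)


assert prefill_kv_computed_tokens(10000, 1200) == 8800  # prefix cache hit
assert prefill_kv_computed_tokens(2000, 0) == 2000      # no prefix cache
assert prefill_kv_computed_tokens(100, -1) == 100       # negative count clamped to 0
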
@@ -1,8 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, RequestStateStats


def test_iteration_stats_repr():
    iteration_stats = IterationStats()
    assert repr(iteration_stats).startswith("IterationStats(")


def test_prefill_kv_computed_with_cache():
    """Test that prefill KV compute correctly excludes cached tokens."""
    iteration_stats = IterationStats()
    req_stats = RequestStateStats(arrival_time=0.0)
    req_stats.scheduled_ts = 0.1
    req_stats.first_token_ts = 0.5
    req_stats.last_token_ts = 5.0
    req_stats.num_generation_tokens = 50

    # Case 1: With prefix cache (1200 tokens cached)
    iteration_stats.update_from_finished_request(
        finish_reason=FinishReason.STOP,
        num_prompt_tokens=10000,
        max_tokens_param=100,
        req_stats=req_stats,
        num_cached_tokens=1200,
    )

    finished_req = iteration_stats.finished_requests[0]
    assert finished_req.num_prompt_tokens == 10000
    assert finished_req.num_cached_tokens == 1200

    # Verify calculation: prefill KV = prompt tokens - cached tokens
    prefill_kv_computed = finished_req.num_prompt_tokens - max(
        finished_req.num_cached_tokens, 0
    )
    assert prefill_kv_computed == 8800  # 10000 - 1200


def test_prefill_kv_computed_no_cache():
    """Test prefill KV compute without prefix caching."""
    iteration_stats = IterationStats()
    req_stats = RequestStateStats(arrival_time=0.0)
    req_stats.scheduled_ts = 0.1
    req_stats.first_token_ts = 0.5
    req_stats.last_token_ts = 2.0
    req_stats.num_generation_tokens = 10

    # Case 2: No prefix cache
    iteration_stats.update_from_finished_request(
        finish_reason=FinishReason.STOP,
        num_prompt_tokens=2000,
        max_tokens_param=100,
        req_stats=req_stats,
        num_cached_tokens=0,
    )

    finished_req = iteration_stats.finished_requests[0]
    assert finished_req.num_prompt_tokens == 2000
    assert finished_req.num_cached_tokens == 0

    # Verify calculation: prefill KV = full prompt when no cache
    prefill_kv_computed = finished_req.num_prompt_tokens - max(
        finished_req.num_cached_tokens, 0
    )
    assert prefill_kv_computed == 2000


def test_prefill_kv_computed_edge_cases():
    """Test edge cases for prefill KV compute calculation."""
    iteration_stats = IterationStats()
    req_stats = RequestStateStats(arrival_time=0.0)
    req_stats.scheduled_ts = 0.1
    req_stats.first_token_ts = 0.5
    req_stats.last_token_ts = 1.0
    req_stats.num_generation_tokens = 1

    # Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully)
    iteration_stats.update_from_finished_request(
        finish_reason=FinishReason.STOP,
        num_prompt_tokens=100,
        max_tokens_param=10,
        req_stats=req_stats,
        num_cached_tokens=-1,
    )

    finished_req = iteration_stats.finished_requests[0]
    # max() should handle negative values
    prefill_kv_computed = finished_req.num_prompt_tokens - max(
        finished_req.num_cached_tokens, 0
    )
    assert prefill_kv_computed == 100  # Should treat negative as 0

    # Case 4: All tokens cached (shouldn't happen in practice)
    iteration_stats2 = IterationStats()
    iteration_stats2.update_from_finished_request(
        finish_reason=FinishReason.STOP,
        num_prompt_tokens=100,
        max_tokens_param=10,
        req_stats=req_stats,
        num_cached_tokens=100,
    )

    finished_req2 = iteration_stats2.finished_requests[0]
    prefill_kv_computed2 = finished_req2.num_prompt_tokens - max(
        finished_req2.num_cached_tokens, 0
    )
    assert prefill_kv_computed2 == 0  # All cached, nothing computed

@@ -650,6 +650,7 @@ class OutputProcessor:
                ),
                max_tokens_param=req_state.max_tokens_param,
                req_stats=req_state.stats,
                num_cached_tokens=req_state.num_cached_tokens,
            )
            self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

@@ -870,6 +870,19 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
            histogram_decode_time_request, engine_indexes, model_name
        )

        histogram_prefill_kv_computed_request = self._histogram_cls(
            name="vllm:request_prefill_kv_computed_tokens",
            documentation=(
                "Histogram of new KV tokens computed during prefill "
                "(excluding cached tokens)."
            ),
            buckets=build_1_2_5_buckets(max_model_len),
            labelnames=labelnames,
        )
        self.histogram_prefill_kv_computed_request = make_per_engine(
            histogram_prefill_kv_computed_request, engine_indexes, model_name
        )

        #
        # KV Cache residency metrics
        #

@@ -1118,6 +1131,13 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
            self.histogram_decode_time_request[engine_idx].observe(
                finished_request.decode_time
            )
            # Calculate prefill KV compute (excludes cached tokens)
            prefill_kv_computed = finished_request.num_prompt_tokens - max(
                finished_request.num_cached_tokens, 0
            )
            self.histogram_prefill_kv_computed_request[engine_idx].observe(
                prefill_kv_computed
            )
            self.histogram_num_prompt_tokens_request[engine_idx].observe(
                finished_request.num_prompt_tokens
            )

@@ -224,6 +224,7 @@ class FinishedRequestStats:
    decode_time: float = 0.0
    mean_time_per_output_token: float = 0.0
    is_corrupted: bool = False
    num_cached_tokens: int = 0


class IterationStats:

@@ -330,6 +331,7 @@ class IterationStats:
        num_prompt_tokens: int,
        max_tokens_param: int | None,
        req_stats: RequestStateStats,
        num_cached_tokens: int = 0,
    ):
        e2e_latency = self._time_since(req_stats.arrival_time)

@@ -367,6 +369,7 @@
            decode_time=decode_time,
            mean_time_per_output_token=mean_time_per_output_token,
            is_corrupted=req_stats.is_corrupted,
            num_cached_tokens=num_cached_tokens,
        )
        self.finished_requests.append(finished_req)

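For context outside the engine, the same observation can be reproduced with plain prometheus_client. This is a minimal sketch, not vLLM's internal wiring: the _histogram_cls/make_per_engine indirection is omitted, the label names and bucket boundaries are assumptions, and the bucket list merely stands in for build_1_2_5_buckets(max_model_len).

from prometheus_client import Histogram

# Illustrative stand-in for the per-engine histogram added in the diff above.
# Label names and bucket boundaries are assumed, not the exact vLLM values.
prefill_kv_computed_tokens = Histogram(
    name="vllm:request_prefill_kv_computed_tokens",
    documentation=(
        "Histogram of new KV tokens computed during prefill "
        "(excluding cached tokens)."
    ),
    labelnames=["model_name", "engine"],
    buckets=[1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000],
)


def observe_finished_request(
    num_prompt_tokens: int, num_cached_tokens: int, model_name: str, engine: str
) -> None:
    # Same formula as the logger change: cached tokens are excluded and a
    # negative cache count is clamped to zero before observing.
    prefill_kv_computed = num_prompt_tokens - max(num_cached_tokens, 0)
    prefill_kv_computed_tokens.labels(model_name=model_name, engine=engine).observe(
        prefill_kv_computed
    )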