vllm/tests/v1/metrics/test_stats.py
Victor Ziliang Peng f1599ca55d
feat(metrics): Add prefill KV compute metric excluding cached tokens (#30189)
Signed-off-by: Ziliang Peng <ziliang@character.ai>
2025-12-09 00:08:48 +00:00

110 lines
3.7 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, RequestStateStats
def test_iteration_stats_repr():
    """Check that repr() of a fresh IterationStats starts with the class name."""
    rendered = repr(IterationStats())
    assert rendered.startswith("IterationStats(")
def test_prefill_kv_computed_with_cache():
    """Check the prefill KV formula when part of the prompt is reported as cached.

    A finished request with 10000 prompt tokens and 1200 cached tokens is
    recorded; the prompt-minus-cached difference must come out as 8800.
    """
    tracker = IterationStats()

    # Per-request timeline and generation count handed to the tracker.
    timeline = RequestStateStats(arrival_time=0.0)
    timeline.scheduled_ts = 0.1
    timeline.first_token_ts = 0.5
    timeline.last_token_ts = 5.0
    timeline.num_generation_tokens = 50

    # Case 1: With prefix cache (1200 tokens cached)
    tracker.update_from_finished_request(
        finish_reason=FinishReason.STOP,
        num_prompt_tokens=10000,
        max_tokens_param=100,
        req_stats=timeline,
        num_cached_tokens=1200,
    )

    record = tracker.finished_requests[0]
    assert record.num_prompt_tokens == 10000
    assert record.num_cached_tokens == 1200

    # Prefill KV = prompt tokens - cached tokens, with the cached count
    # floored at zero before subtracting.
    cached_floor = max(record.num_cached_tokens, 0)
    computed = record.num_prompt_tokens - cached_floor
    assert computed == 8800  # 10000 - 1200
def test_prefill_kv_computed_no_cache():
    """Check the prefill KV formula when no prompt tokens are reported as cached.

    With zero cached tokens, the prompt-minus-cached difference must equal
    the full prompt length (2000).
    """
    prompt_len = 2000

    tracker = IterationStats()

    # Per-request timeline and generation count handed to the tracker.
    timeline = RequestStateStats(arrival_time=0.0)
    timeline.scheduled_ts = 0.1
    timeline.first_token_ts = 0.5
    timeline.last_token_ts = 2.0
    timeline.num_generation_tokens = 10

    # Case 2: No prefix cache
    tracker.update_from_finished_request(
        finish_reason=FinishReason.STOP,
        num_prompt_tokens=prompt_len,
        max_tokens_param=100,
        req_stats=timeline,
        num_cached_tokens=0,
    )

    record = tracker.finished_requests[0]
    assert record.num_prompt_tokens == prompt_len
    assert record.num_cached_tokens == 0

    # Nothing cached, so the whole prompt counts as computed prefill KV.
    cached_floor = max(record.num_cached_tokens, 0)
    computed = record.num_prompt_tokens - cached_floor
    assert computed == prompt_len
def test_prefill_kv_computed_edge_cases():
    """Check the prefill KV formula at its boundaries.

    Two finished requests are recorded on separate IterationStats objects,
    sharing one RequestStateStats: one with a negative cached-token count
    and one where the cached count equals the prompt length.
    """
    negative_cache_tracker = IterationStats()

    # Per-request timeline and generation count, reused by both cases below.
    timeline = RequestStateStats(arrival_time=0.0)
    timeline.scheduled_ts = 0.1
    timeline.first_token_ts = 0.5
    timeline.last_token_ts = 1.0
    timeline.num_generation_tokens = 1

    # Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully)
    negative_cache_tracker.update_from_finished_request(
        finish_reason=FinishReason.STOP,
        num_prompt_tokens=100,
        max_tokens_param=10,
        req_stats=timeline,
        num_cached_tokens=-1,
    )
    negative_record = negative_cache_tracker.finished_requests[0]

    # Flooring at zero keeps a negative cached count from inflating the result.
    negative_floor = max(negative_record.num_cached_tokens, 0)
    negative_computed = negative_record.num_prompt_tokens - negative_floor
    assert negative_computed == 100  # Should treat negative as 0

    # Case 4: All tokens cached (shouldn't happen in practice)
    full_cache_tracker = IterationStats()
    full_cache_tracker.update_from_finished_request(
        finish_reason=FinishReason.STOP,
        num_prompt_tokens=100,
        max_tokens_param=10,
        req_stats=timeline,
        num_cached_tokens=100,
    )
    full_record = full_cache_tracker.finished_requests[0]

    full_floor = max(full_record.num_cached_tokens, 0)
    full_computed = full_record.num_prompt_tokens - full_floor
    assert full_computed == 0  # All cached, nothing computed