diff --git a/tests/v1/metrics/test_stats.py b/tests/v1/metrics/test_stats.py index 48067def8357..7d902bbc6fc2 100644 --- a/tests/v1/metrics/test_stats.py +++ b/tests/v1/metrics/test_stats.py @@ -1,8 +1,109 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.v1.metrics.stats import IterationStats +from vllm.v1.engine import FinishReason +from vllm.v1.metrics.stats import IterationStats, RequestStateStats def test_iteration_stats_repr(): iteration_stats = IterationStats() assert repr(iteration_stats).startswith("IterationStats(") + + +def test_prefill_kv_computed_with_cache(): + """Test that prefill KV compute correctly excludes cached tokens.""" + iteration_stats = IterationStats() + req_stats = RequestStateStats(arrival_time=0.0) + req_stats.scheduled_ts = 0.1 + req_stats.first_token_ts = 0.5 + req_stats.last_token_ts = 5.0 + req_stats.num_generation_tokens = 50 + + # Case 1: With prefix cache (1200 tokens cached) + iteration_stats.update_from_finished_request( + finish_reason=FinishReason.STOP, + num_prompt_tokens=10000, + max_tokens_param=100, + req_stats=req_stats, + num_cached_tokens=1200, + ) + + finished_req = iteration_stats.finished_requests[0] + assert finished_req.num_prompt_tokens == 10000 + assert finished_req.num_cached_tokens == 1200 + + # Verify calculation: prefill KV = prompt tokens - cached tokens + prefill_kv_computed = finished_req.num_prompt_tokens - max( + finished_req.num_cached_tokens, 0 + ) + assert prefill_kv_computed == 8800 # 10000 - 1200 + + +def test_prefill_kv_computed_no_cache(): + """Test prefill KV compute without prefix caching.""" + iteration_stats = IterationStats() + req_stats = RequestStateStats(arrival_time=0.0) + req_stats.scheduled_ts = 0.1 + req_stats.first_token_ts = 0.5 + req_stats.last_token_ts = 2.0 + req_stats.num_generation_tokens = 10 + + # Case 2: No prefix cache + iteration_stats.update_from_finished_request( + finish_reason=FinishReason.STOP, + num_prompt_tokens=2000, + max_tokens_param=100, + req_stats=req_stats, + num_cached_tokens=0, + ) + + finished_req = iteration_stats.finished_requests[0] + assert finished_req.num_prompt_tokens == 2000 + assert finished_req.num_cached_tokens == 0 + + # Verify calculation: prefill KV = full prompt when no cache + prefill_kv_computed = finished_req.num_prompt_tokens - max( + finished_req.num_cached_tokens, 0 + ) + assert prefill_kv_computed == 2000 + + +def test_prefill_kv_computed_edge_cases(): + """Test edge cases for prefill KV compute calculation.""" + iteration_stats = IterationStats() + req_stats = RequestStateStats(arrival_time=0.0) + req_stats.scheduled_ts = 0.1 + req_stats.first_token_ts = 0.5 + req_stats.last_token_ts = 1.0 + req_stats.num_generation_tokens = 1 + + # Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully) + iteration_stats.update_from_finished_request( + finish_reason=FinishReason.STOP, + num_prompt_tokens=100, + max_tokens_param=10, + req_stats=req_stats, + num_cached_tokens=-1, + ) + + finished_req = iteration_stats.finished_requests[0] + # max() should handle negative values + prefill_kv_computed = finished_req.num_prompt_tokens - max( + finished_req.num_cached_tokens, 0 + ) + assert prefill_kv_computed == 100 # Should treat negative as 0 + + # Case 4: All tokens cached (shouldn't happen in practice) + iteration_stats2 = IterationStats() + iteration_stats2.update_from_finished_request( + finish_reason=FinishReason.STOP, + num_prompt_tokens=100, + max_tokens_param=10, + req_stats=req_stats, + num_cached_tokens=100, + ) + + finished_req2 = iteration_stats2.finished_requests[0] + prefill_kv_computed2 = finished_req2.num_prompt_tokens - max( + finished_req2.num_cached_tokens, 0 + ) + assert prefill_kv_computed2 == 0 # All cached, nothing computed diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index e85fbb4ee0fb..9be3f4da7352 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -650,6 +650,7 @@ class OutputProcessor: ), max_tokens_param=req_state.max_tokens_param, req_stats=req_state.stats, + num_cached_tokens=req_state.num_cached_tokens, ) self.lora_states.request_finished(req_state.request_id, req_state.lora_name) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 882e0ce0b2e0..9eaee1bb97bb 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -870,6 +870,19 @@ class PrometheusStatLogger(AggregateStatLoggerBase): histogram_decode_time_request, engine_indexes, model_name ) + histogram_prefill_kv_computed_request = self._histogram_cls( + name="vllm:request_prefill_kv_computed_tokens", + documentation=( + "Histogram of new KV tokens computed during prefill " + "(excluding cached tokens)." + ), + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames, + ) + self.histogram_prefill_kv_computed_request = make_per_engine( + histogram_prefill_kv_computed_request, engine_indexes, model_name + ) + # # KV Cache residency metrics # @@ -1118,6 +1131,13 @@ class PrometheusStatLogger(AggregateStatLoggerBase): self.histogram_decode_time_request[engine_idx].observe( finished_request.decode_time ) + # Calculate prefill KV compute (excludes cached tokens) + prefill_kv_computed = finished_request.num_prompt_tokens - max( + finished_request.num_cached_tokens, 0 + ) + self.histogram_prefill_kv_computed_request[engine_idx].observe( + prefill_kv_computed + ) self.histogram_num_prompt_tokens_request[engine_idx].observe( finished_request.num_prompt_tokens ) diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 733d3ae12e67..a0cc58d0a64e 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -224,6 +224,7 @@ class FinishedRequestStats: decode_time: float = 0.0 mean_time_per_output_token: float = 0.0 is_corrupted: bool = False + num_cached_tokens: int = 0 class IterationStats: @@ -330,6 +331,7 @@ class IterationStats: num_prompt_tokens: int, max_tokens_param: int | None, req_stats: RequestStateStats, + num_cached_tokens: int = 0, ): e2e_latency = self._time_since(req_stats.arrival_time) @@ -367,6 +369,7 @@ class IterationStats: decode_time=decode_time, mean_time_per_output_token=mean_time_per_output_token, is_corrupted=req_stats.is_corrupted, + num_cached_tokens=num_cached_tokens, ) self.finished_requests.append(finished_req)