diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 319e6e84fba1..4bf6bbbfeae2 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -513,27 +513,27 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None)) +def _stats(requests: int, queries: int, hits: int) -> PrefixCacheStats: + return PrefixCacheStats(requests=requests, queries=queries, hits=hits) + + def test_metrics(): """ Test the prefix caching metrics. """ - - def stats(requests, queries, hits): - return PrefixCacheStats(requests=requests, queries=queries, hits=hits) - metrics = PrefixCachingMetrics(max_recent_requests=5) assert metrics.hit_rate == 0.0 - metrics.observe(stats(1, 20, 9)) + metrics.observe(_stats(1, 20, 9)) # 9 / 20 = 0.45 assert metrics.hit_rate == 0.45 - metrics.observe(stats(4, 80, 16)) + metrics.observe(_stats(4, 80, 16)) # 25 / 100 = 0.25 assert metrics.hit_rate == 0.25 - metrics.observe(stats(1, 10, 2)) + metrics.observe(_stats(1, 10, 2)) # Remove (20, 9) and add (10, 2): 18 / 90 = 0.2 assert metrics.aggregated_requests == 5 @@ -549,6 +549,38 @@ def test_metrics(): assert not metrics.query_queue +def test_metrics_empty_stats(): + """ + Test the prefix caching metrics with empty stats. + """ + metrics = PrefixCachingMetrics(max_recent_requests=5) + metrics.observe(_stats(0, 0, 0)) + metrics.observe(_stats(1, 20, 9)) + metrics.observe(_stats(0, 0, 0)) + metrics.observe(_stats(4, 80, 16)) + metrics.observe(_stats(0, 0, 0)) + metrics.observe(_stats(1, 10, 2)) + # Remove (20, 9) and add (10, 2): 18 / 90 = 0.2 + assert metrics.aggregated_requests == 5 + assert metrics.aggregated_query_total == 90 + assert metrics.aggregated_query_hit == 18 + assert metrics.hit_rate == 0.2 + + # Only the latest added stats preserved 10 / 20 = 0.5 + metrics.observe(_stats(11, 20, 10)) + assert metrics.aggregated_requests == 11 + assert metrics.aggregated_query_total == 20 + assert metrics.aggregated_query_hit == 10 + assert metrics.hit_rate == 0.5 + + # Only the latest added stats preserved 30 / 40 = 0.75 + metrics.observe(_stats(22, 40, 30)) + assert metrics.aggregated_requests == 22 + assert metrics.aggregated_query_total == 40 + assert metrics.aggregated_query_hit == 30 + assert metrics.hit_rate == 0.75 + + def test_get_kv_cache_configs_multiple_workers(): model_config = ModelConfig(max_model_len=16) vllm_config = VllmConfig(model_config=model_config) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9fab36aba91b..bc2ec5e42ea2 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -127,14 +127,23 @@ class PrefixCachingMetrics: if stats.reset: self.reset() + # DO NOT appending empty stats to avoid helpful info get kicked out + # due to sliding window. + if stats.requests == 0: + return + # Update the metrics. self.query_queue.append((stats.requests, stats.queries, stats.hits)) self.aggregated_requests += stats.requests self.aggregated_query_total += stats.queries self.aggregated_query_hit += stats.hits - # Remove the oldest stats if the number of requests exceeds. - if self.aggregated_requests > self.max_recent_requests: + # Remove the oldest stats until number of requests does not exceed + # the limit. + # NOTE: We preserve the latest added stats regardless. + while len( + self.query_queue + ) > 1 and self.aggregated_requests > self.max_recent_requests: old_requests, old_queries, old_hits = self.query_queue.popleft() self.aggregated_requests -= old_requests self.aggregated_query_total -= old_queries