[Core][Observability] Add KV cache residency metrics (#27793)

Introduces three new Prometheus histograms for fine-grained observability of KV cache residency behavior:

vllm:kv_block_lifetime_seconds — total lifetime from allocation to free
vllm:kv_block_idle_before_evict_seconds — idle duration before eviction
vllm:kv_block_reuse_gap_seconds — time between consecutive reuses of the same block

These metrics help operators analyze KV cache efficiency, reuse patterns, and eviction timing beyond simple utilization rates.

Implementation uses monotonic timestamps for accuracy, 1% sampling for minimal overhead (~48 bytes/block), and is fully thread-safe with zero runtime cost when disabled.

Two new runtime flags are introduced:

--kv-cache-metrics – enable KV cache residency metrics
--kv-cache-metrics-sample – control sampling ratio (default: 0.01)

Signed-off-by: Shivam <shivamprasad91@gmail.com>
This commit is contained in:
shivampr 2025-12-01 10:27:53 -08:00 committed by GitHub
parent ec7035c9d4
commit cabc77cc86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 534 additions and 13 deletions

View File

@ -263,6 +263,29 @@ record:
- End-to-end latency - the interval between frontend `arrival_time`
and the frontend receiving the final token.
### KV Cache Residency Metrics
We also emit a set of histograms that describe how long sampled KV cache
blocks stay resident and how often they are reused. Sampling
(`--kv-cache-metrics-sample`) keeps the overhead tiny; when a block is
chosen we record:
- `lifetime` allocation ⟶ eviction
- `idle before eviction` last touch ⟶ eviction
- `reuse gaps` the pauses between touches when the block gets reused
Those map directly to the Prometheus metrics:
- `vllm:kv_block_lifetime_seconds` how long each sampled block exists.
- `vllm:kv_block_idle_before_evict_seconds` idle tail after the final access.
- `vllm:kv_block_reuse_gap_seconds` time between consecutive touches.
The engine core only ships raw eviction events via `SchedulerStats`; the
frontend drains them, turns them into Prometheus observations, and also
exposes the same data through `LLM.get_metrics()` when logging is on.
Looking at lifetime and idle time on one chart makes it easy to spot
stranded cache or workloads that pin prompts for a long decode.
### Metrics Publishing - Logging
The `LoggingStatLogger` metrics publisher outputs a log `INFO` message

View File

@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch
import pytest
from vllm.v1.core.kv_cache_metrics import (
BlockMetricsState,
KVCacheMetricsCollector,
)
from vllm.v1.core.kv_cache_utils import KVCacheBlock
class TestBlockMetricsState:
def test_init(self):
with patch("time.monotonic_ns", return_value=1000000000):
state = BlockMetricsState()
assert state.birth_time_ns == 1000000000
assert state.last_access_ns == 1000000000
assert len(state.access_history) == 0
def test_access_tracking(self):
with patch("time.monotonic_ns", return_value=1000000000):
state = BlockMetricsState()
with patch("time.monotonic_ns", return_value=2000000000):
state.record_access()
assert state.last_access_ns == 2000000000
assert list(state.access_history) == [2000000000]
def test_ring_buffer_wraps_at_4(self):
with patch("time.monotonic_ns", return_value=1000000000):
state = BlockMetricsState()
for i in range(5):
t = 1000000000 + (i + 1) * 1000000000
with patch("time.monotonic_ns", return_value=t):
state.record_access()
assert len(state.access_history) == 4
assert list(state.access_history) == [
3000000000,
4000000000,
5000000000,
6000000000,
]
def test_lifetime(self):
with patch("time.monotonic_ns", return_value=1000000000):
state = BlockMetricsState()
with patch("time.monotonic_ns", return_value=6500000000):
assert abs(state.get_lifetime_seconds() - 5.5) < 0.001
def test_idle_time(self):
with patch("time.monotonic_ns", return_value=1000000000):
state = BlockMetricsState()
state.last_access_ns = 2000000000
with patch("time.monotonic_ns", return_value=5200000000):
assert abs(state.get_idle_time_seconds() - 3.2) < 0.001
def test_reuse_gaps(self):
with patch("time.monotonic_ns", return_value=1000000000):
state = BlockMetricsState()
base = 1000000000
for offset in [0, 1.5, 3.0, 5.5]:
state.access_history.append(base + int(offset * 1e9))
gaps = state.get_reuse_gaps_seconds()
assert len(gaps) == 3
assert gaps[0] == 1.5 and gaps[1] == 1.5 and gaps[2] == 2.5
def test_ring_wrap_only_gives_3_gaps(self):
# 5 accesses in size-4 buffer = 3 gaps
with patch("time.monotonic_ns", return_value=1000000000):
state = BlockMetricsState()
for i in range(5):
state.access_history.append(1000000000 + i * 1000000000)
assert len(state.get_reuse_gaps_seconds()) == 3
class TestKVCacheMetricsCollector:
def test_sample_rate_validation(self):
with pytest.raises(AssertionError):
KVCacheMetricsCollector(sample_rate=-0.1)
with pytest.raises(AssertionError):
KVCacheMetricsCollector(sample_rate=1.5)
with pytest.raises(AssertionError):
KVCacheMetricsCollector(sample_rate=0.0)
def test_sampling(self):
c = KVCacheMetricsCollector(sample_rate=1.0)
assert sum(1 for _ in range(100) if c.should_sample_block()) == 100
c = KVCacheMetricsCollector(sample_rate=0.5)
samples = sum(1 for _ in range(1000) if c.should_sample_block())
assert 400 < samples < 600
def test_alloc(self):
c = KVCacheMetricsCollector(sample_rate=1.0)
blocks = [KVCacheBlock(block_id=i) for i in range(5)]
with patch("time.monotonic_ns", return_value=1000000000):
for block in blocks:
c.on_block_allocated(block)
assert len(c.block_metrics) == 5
def test_access(self):
c = KVCacheMetricsCollector(sample_rate=1.0)
block = KVCacheBlock(block_id=0)
with patch("time.monotonic_ns", return_value=1000000000):
c.on_block_allocated(block)
for i in range(3):
t = 1000000000 + (i + 1) * 1000000000
with patch("time.monotonic_ns", return_value=t):
c.on_block_accessed(block)
assert len(c.block_metrics[0].access_history) == 3
def test_evict_no_accesses(self):
# lifetime should equal idle if never accessed
c = KVCacheMetricsCollector(sample_rate=1.0)
block = KVCacheBlock(block_id=0)
with patch("time.monotonic_ns", return_value=1000000000):
c.on_block_allocated(block)
with patch("time.monotonic_ns", return_value=6000000000):
c.on_block_evicted(block)
events = c.drain_events()
assert len(events) == 1
assert abs(events[0].lifetime_seconds - 5.0) < 0.001
assert abs(events[0].idle_seconds - 5.0) < 0.001
def test_evict(self):
c = KVCacheMetricsCollector(sample_rate=1.0)
block = KVCacheBlock(block_id=0)
with patch("time.monotonic_ns", return_value=1000000000):
c.on_block_allocated(block)
with patch("time.monotonic_ns", return_value=2000000000):
c.on_block_accessed(block)
with patch("time.monotonic_ns", return_value=3000000000):
c.on_block_accessed(block)
with patch("time.monotonic_ns", return_value=4000000000):
c.on_block_evicted(block)
events = c.drain_events()
assert len(events) == 1
sample = events[0]
assert abs(sample.lifetime_seconds - 3.0) < 0.001
assert abs(sample.idle_seconds - 1.0) < 0.001
assert sample.reuse_gaps_seconds == (1.0,)
assert 0 not in c.block_metrics
def test_reset(self):
c = KVCacheMetricsCollector(sample_rate=1.0)
with patch("time.monotonic_ns", return_value=1000000000):
for i in range(5):
c.on_block_allocated(KVCacheBlock(block_id=i))
assert len(c.block_metrics) == 5
c.reset()
assert len(c.block_metrics) == 0
with patch("time.monotonic_ns", return_value=2000000000):
c.on_block_allocated(KVCacheBlock(block_id=10))
assert 10 in c.block_metrics
def test_huge_time_jump(self):
c = KVCacheMetricsCollector(sample_rate=1.0)
block = KVCacheBlock(block_id=0)
with patch("time.monotonic_ns", return_value=1000000000):
c.on_block_allocated(block)
with patch("time.monotonic_ns", return_value=9999999999999999):
c.on_block_evicted(block)
events = c.drain_events()
assert len(events) == 1
assert events[0].lifetime_seconds > 0
def test_kv_cache_metrics_collector_smoke() -> None:
"""Simple smoke test for KVCacheMetricsCollector on CPU."""
collector = KVCacheMetricsCollector(sample_rate=1.0)
block = KVCacheBlock(block_id=123)
# Allocate at t = 1.0s.
with patch("time.monotonic_ns", return_value=1_000_000_000):
collector.on_block_allocated(block)
# Access at t = 2.0s and t = 3.0s.
with patch("time.monotonic_ns", return_value=2_000_000_000):
collector.on_block_accessed(block)
with patch("time.monotonic_ns", return_value=3_000_000_000):
collector.on_block_accessed(block)
# Evict at t = 4.0s.
with patch("time.monotonic_ns", return_value=4_000_000_000):
collector.on_block_evicted(block)
events = collector.drain_events()
assert len(events) == 1
event = events[0]
# Lifetime: 1.0s → 4.0s.
assert abs(event.lifetime_seconds - 3.0) < 1e-6
# Idle: last access at 3.0s, evicted at 4.0s.
assert abs(event.idle_seconds - 1.0) < 1e-6
# One reuse gap between the two accesses.
assert event.reuse_gaps_seconds == (1.0,)

View File

@ -5,7 +5,7 @@ from functools import cached_property
from typing import Any, Literal, cast
from packaging.version import parse
from pydantic import field_validator, model_validator
from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm import version
@ -47,6 +47,14 @@ class ObservabilityConfig:
Note that collecting detailed timing information for each request can be
expensive."""
kv_cache_metrics: bool = False
"""Enable KV cache residency metrics (lifetime, idle time, reuse gaps).
Uses sampling to minimize overhead.
Requires log stats to be enabled (i.e., --disable-log-stats not set)."""
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
@cached_property
def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request."""

View File

@ -517,6 +517,10 @@ class EngineArgs:
collect_detailed_traces: list[DetailedTraceModules] | None = (
ObservabilityConfig.collect_detailed_traces
)
kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics
kv_cache_metrics_sample: float = get_field(
ObservabilityConfig, "kv_cache_metrics_sample"
)
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
@ -1013,6 +1017,13 @@ class EngineArgs:
"--collect-detailed-traces",
**observability_kwargs["collect_detailed_traces"],
)
observability_group.add_argument(
"--kv-cache-metrics", **observability_kwargs["kv_cache_metrics"]
)
observability_group.add_argument(
"--kv-cache-metrics-sample",
**observability_kwargs["kv_cache_metrics_sample"],
)
# Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig)
@ -1698,6 +1709,8 @@ class EngineArgs:
show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
otlp_traces_endpoint=self.otlp_traces_endpoint,
collect_detailed_traces=self.collect_detailed_traces,
kv_cache_metrics=self.kv_cache_metrics,
kv_cache_metrics_sample=self.kv_cache_metrics_sample,
)
# Compilation config overrides

View File

@ -11,6 +11,7 @@ from vllm.distributed.kv_events import (
KVCacheEvent,
)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
@ -140,6 +141,7 @@ class BlockPool:
where different KV cache groups have different block sizes, the
actual block size can be a multiple of hash_block_size.
enable_kv_cache_events: Whether to enable kv cache events.
metrics_collector: Optional metrics collector for tracking block residency.
"""
def __init__(
@ -148,6 +150,7 @@ class BlockPool:
enable_caching: bool,
hash_block_size: int,
enable_kv_cache_events: bool = False,
metrics_collector: KVCacheMetricsCollector | None = None,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
@ -174,6 +177,8 @@ class BlockPool:
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []
self.metrics_collector = metrics_collector
def get_cached_block(
self, block_hash: BlockHash, kv_cache_group_ids: list[int]
) -> list[KVCacheBlock] | None:
@ -308,10 +313,14 @@ class BlockPool:
self._maybe_evict_cached_block(block)
assert block.ref_cnt == 0
block.ref_cnt += 1
if self.metrics_collector:
self.metrics_collector.on_block_allocated(block)
else:
for block in ret:
assert block.ref_cnt == 0
block.ref_cnt += 1
if self.metrics_collector:
self.metrics_collector.on_block_allocated(block)
return ret
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
@ -325,6 +334,10 @@ class BlockPool:
Returns:
True if the block is evicted, False otherwise.
"""
# Clean up metrics tracking first to prevent leaks
if self.metrics_collector:
self.metrics_collector.on_block_evicted(block)
block_hash = block.block_hash
if block_hash is None:
# The block doesn't have hash, eviction is not needed
@ -365,6 +378,8 @@ class BlockPool:
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.ref_cnt += 1
if self.metrics_collector:
self.metrics_collector.on_block_accessed(block)
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their
@ -407,6 +422,9 @@ class BlockPool:
for block in self.blocks:
block.reset_hash()
if self.metrics_collector:
self.metrics_collector.reset()
logger.info("Successfully reset prefix cache")
if self.enable_kv_cache_events:

View File

@ -5,6 +5,7 @@ from collections.abc import Sequence
from math import lcm
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
@ -39,6 +40,7 @@ class KVCacheCoordinator(ABC):
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
@ -49,6 +51,7 @@ class KVCacheCoordinator(ABC):
enable_caching,
hash_block_size,
enable_kv_cache_events,
metrics_collector,
)
# Needs special handling for find_longest_cache_hit if eagle is enabled
@ -228,6 +231,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
super().__init__(
kv_cache_config,
@ -238,6 +242,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
self.num_single_type_manager = len(self.single_type_managers)
@ -272,6 +277,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
super().__init__(
kv_cache_config,
@ -282,6 +288,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
@ -338,6 +345,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
super().__init__(
kv_cache_config,
@ -348,6 +356,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
# hash_block_size: the block size used to compute block hashes.
# The actual block size usually equals hash_block_size, but in cases where
@ -523,6 +532,7 @@ def get_kv_cache_coordinator(
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(
@ -530,9 +540,10 @@ def get_kv_cache_coordinator(
max_model_len,
use_eagle,
enable_kv_cache_events,
dcp_world_size,
pcp_world_size,
hash_block_size,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(
@ -541,9 +552,10 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size,
pcp_world_size,
hash_block_size,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
return HybridKVCacheCoordinator(
kv_cache_config,
@ -551,7 +563,8 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size,
pcp_world_size,
hash_block_size,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)

View File

@ -9,6 +9,7 @@ from typing import Literal, overload
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
@ -102,12 +103,14 @@ class KVCacheManager:
enable_kv_cache_events: bool = False,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
metrics_collector: KVCacheMetricsCollector | None = None,
) -> None:
self.max_model_len = max_model_len
self.enable_caching = enable_caching
self.use_eagle = use_eagle
self.log_stats = log_stats
self.metrics_collector = metrics_collector
# FIXME: make prefix cache stats conditional on log_stats. We still need
# this comment because when the log stats is enabled there are still
# potential configs we could expose in the future.
@ -122,6 +125,7 @@ class KVCacheManager:
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=self.metrics_collector,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool

View File

@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV cache metrics tracking."""
import random
import time
from collections import deque
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.metrics.stats import KVCacheEvictionEvent
class BlockMetricsState:
"""Tracks lifecycle metrics for a single KV cache block."""
def __init__(self):
now_ns = time.monotonic_ns()
self.birth_time_ns = now_ns
self.last_access_ns = now_ns
# Bounded to prevent unbounded growth if a block is accessed many times.
self.access_history: deque[int] = deque(maxlen=4)
def record_access(self) -> None:
now_ns = time.monotonic_ns()
self.last_access_ns = now_ns
self.access_history.append(now_ns)
def get_lifetime_seconds(self) -> float:
now_ns = time.monotonic_ns()
return (now_ns - self.birth_time_ns) / 1e9
def get_idle_time_seconds(self) -> float:
now_ns = time.monotonic_ns()
return (now_ns - self.last_access_ns) / 1e9
def get_reuse_gaps_seconds(self) -> list[float]:
if len(self.access_history) < 2:
return []
history = list(self.access_history)
return [(history[i] - history[i - 1]) / 1e9 for i in range(1, len(history))]
class KVCacheMetricsCollector:
"""Collects KV cache residency metrics with sampling."""
def __init__(self, sample_rate: float = 0.01):
assert 0 < sample_rate <= 1.0, (
f"sample_rate must be in (0, 1.0], got {sample_rate}"
)
self.sample_rate = sample_rate
self.block_metrics: dict[int, BlockMetricsState] = {}
self._eviction_events: list[KVCacheEvictionEvent] = []
def should_sample_block(self) -> bool:
return random.random() < self.sample_rate
def on_block_allocated(self, block: "KVCacheBlock") -> None:
if self.should_sample_block():
self.block_metrics[block.block_id] = BlockMetricsState()
def on_block_accessed(self, block: "KVCacheBlock") -> None:
metrics = self.block_metrics.get(block.block_id)
if metrics:
metrics.record_access()
def on_block_evicted(self, block: "KVCacheBlock") -> None:
metrics = self.block_metrics.pop(block.block_id, None)
if not metrics:
return
lifetime = metrics.get_lifetime_seconds()
idle_time = metrics.get_idle_time_seconds()
reuse_gaps = tuple(metrics.get_reuse_gaps_seconds())
self._eviction_events.append(
KVCacheEvictionEvent(
lifetime_seconds=lifetime,
idle_seconds=idle_time,
reuse_gaps_seconds=reuse_gaps,
)
)
def reset(self) -> None:
"""Clear all state on cache reset."""
self.block_metrics.clear()
self._eviction_events.clear()
def drain_events(self) -> list[KVCacheEvictionEvent]:
events = self._eviction_events
self._eviction_events = []
return events

View File

@ -29,6 +29,7 @@ from vllm.v1.core.encoder_cache_manager import (
compute_encoder_budget,
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (
CachedRequestData,
@ -40,7 +41,10 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.metrics.stats import (
PrefixCacheStats,
SchedulerStats,
)
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
@ -69,6 +73,12 @@ class Scheduler(SchedulerInterface):
self.kv_events_config = vllm_config.kv_events_config
self.parallel_config = vllm_config.parallel_config
self.log_stats = log_stats
self.observability_config = vllm_config.observability_config
self.kv_metrics_collector: KVCacheMetricsCollector | None = None
if self.observability_config.kv_cache_metrics:
self.kv_metrics_collector = KVCacheMetricsCollector(
self.observability_config.kv_cache_metrics_sample,
)
self.structured_output_manager = structured_output_manager
self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder
@ -187,6 +197,7 @@ class Scheduler(SchedulerInterface):
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
hash_block_size=self.block_size,
metrics_collector=self.kv_metrics_collector,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
@ -1356,14 +1367,24 @@ class Scheduler(SchedulerInterface):
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
assert prefix_cache_stats is not None
connector_prefix_cache_stats = self._make_connector_prefix_cache_stats()
eviction_events = (
self.kv_metrics_collector.drain_events()
if self.kv_metrics_collector is not None
else []
)
spec_stats = spec_decoding_stats
connector_stats_payload = (
kv_connector_stats.data if kv_connector_stats else None
)
return SchedulerStats(
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
kv_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=prefix_cache_stats,
connector_prefix_cache_stats=connector_prefix_cache_stats,
spec_decoding_stats=spec_decoding_stats,
kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None,
kv_cache_eviction_events=eviction_events,
spec_decoding_stats=spec_stats,
kv_connector_stats=connector_stats_payload,
)
def make_spec_decoding_stats(

View File

@ -375,6 +375,9 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
# Use this flag to hide metrics that were deprecated in
# a previous release and which will be removed future
self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics
self.kv_cache_metrics_enabled = (
vllm_config.observability_config.kv_cache_metrics
)
labelnames = ["model_name", "engine"]
model_name = vllm_config.model_config.served_model_name
@ -853,6 +856,79 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
histogram_decode_time_request, engine_indexes, model_name
)
#
# KV Cache residency metrics
#
if self.kv_cache_metrics_enabled:
kv_cache_residency_buckets = [
0.001,
0.002,
0.005,
0.01,
0.02,
0.05,
0.1,
0.2,
0.5,
1,
2,
5,
10,
20,
30,
60,
120,
300,
600,
1200,
1800,
]
histogram_kv_block_lifetime = self._histogram_cls(
name="vllm:kv_block_lifetime_seconds",
documentation=(
"Histogram of KV cache block lifetime from allocation to eviction. "
"Sampled metrics (controlled by --kv-cache-metrics-sample)."
),
buckets=kv_cache_residency_buckets,
labelnames=labelnames,
)
self.histogram_kv_block_lifetime = make_per_engine(
histogram_kv_block_lifetime, engine_indexes, model_name
)
histogram_kv_block_idle_before_evict = self._histogram_cls(
name="vllm:kv_block_idle_before_evict_seconds",
documentation=(
"Histogram of idle time before KV cache block eviction. "
"Sampled metrics (controlled by --kv-cache-metrics-sample)."
),
buckets=kv_cache_residency_buckets,
labelnames=labelnames,
)
self.histogram_kv_block_idle_before_evict = make_per_engine(
histogram_kv_block_idle_before_evict, engine_indexes, model_name
)
histogram_kv_block_reuse_gap = self._histogram_cls(
name="vllm:kv_block_reuse_gap_seconds",
documentation=(
"Histogram of time gaps between consecutive KV cache block "
"accesses. Only the most recent accesses are recorded "
"(ring buffer). Sampled metrics (controlled by "
"--kv-cache-metrics-sample)."
),
buckets=kv_cache_residency_buckets,
labelnames=labelnames,
)
self.histogram_kv_block_reuse_gap = make_per_engine(
histogram_kv_block_reuse_gap, engine_indexes, model_name
)
else:
self.histogram_kv_block_lifetime = {}
self.histogram_kv_block_idle_before_evict = {}
self.histogram_kv_block_reuse_gap = {}
#
# LoRA metrics
#
@ -944,6 +1020,20 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
scheduler_stats.kv_connector_stats, engine_idx
)
if (
self.kv_cache_metrics_enabled
and scheduler_stats.kv_cache_eviction_events
):
lifetime_hist = self.histogram_kv_block_lifetime[engine_idx]
idle_hist = self.histogram_kv_block_idle_before_evict[engine_idx]
reuse_hist = self.histogram_kv_block_reuse_gap[engine_idx]
for event in scheduler_stats.kv_cache_eviction_events:
lifetime_hist.observe(event.lifetime_seconds)
idle_hist.observe(event.idle_seconds)
for gap in event.reuse_gaps_seconds:
reuse_hist.observe(gap)
if self.gauge_lora_info is not None:
running_lora_adapters = ",".join(
scheduler_stats.running_lora_adapters.keys()

View File

@ -150,6 +150,15 @@ class MultiModalCacheStats(BaseCacheStats):
"""
@dataclass
class KVCacheEvictionEvent:
"""Single KV cache block eviction sample."""
lifetime_seconds: float
idle_seconds: float
reuse_gaps_seconds: tuple[float, ...]
@dataclass
class SchedulerStats:
"""Stats associated with the scheduler."""
@ -166,6 +175,8 @@ class SchedulerStats:
prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
connector_prefix_cache_stats: PrefixCacheStats | None = None
kv_cache_eviction_events: list[KVCacheEvictionEvent] = field(default_factory=list)
spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: dict[str, Any] | None = None