[Metrics] [KVConnector] Add connector prefix cache hit rate stats (#26245)

Signed-off-by: tovam <tovam@pliops.com>
This commit is contained in:
Tova Movshovitz 2025-10-23 13:21:08 +03:00 committed by GitHub
parent d00ce29d89
commit 88afa11010
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 152 additions and 11 deletions

View File

@ -1014,6 +1014,66 @@ def test_kv_connector_basic():
) )
def test_external_prefix_cache_metrics():
    """
    Check that the scheduler reports connector prefix cache statistics
    after it schedules requests for which the KV connector reports hits.
    """
    # Workload parameters shared by the setup and the assertions below.
    num_reqs = 2
    tokens_per_req = 8
    max_new_tokens = 2
    matched_tokens_per_req = 4

    # Scheduler with a KV connector and with enable_prefix_caching=False.
    scheduler = create_scheduler(
        enable_prefix_caching=False,
        use_kv_connector=True,
    )

    # Stub the connector lookup: every call reports the same number of
    # matched tokens, with False as the second tuple element.
    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
    scheduler.connector.get_num_new_matched_tokens.return_value = (
        matched_tokens_per_req,
        False,
    )

    # Build the requests and hand them to the scheduler.
    requests = create_requests(
        num_requests=num_reqs,
        num_tokens=tokens_per_req,
        max_tokens=max_new_tokens,
    )
    for request in requests:
        scheduler.add_request(request)

    # Run one scheduling step, then feed back a fake model runner output
    # carrying one sampled token per request.
    scheduler_output = scheduler.schedule()
    runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index={
            request.request_id: index for index, request in enumerate(requests)
        },
        sampled_token_ids=[[1000]] * num_reqs,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    engine_outputs = scheduler.update_from_output(scheduler_output, runner_output)

    # The scheduler stats must be present and carry the connector stats.
    assert engine_outputs is not None
    assert len(engine_outputs) > 0
    scheduler_stats = engine_outputs[0].scheduler_stats
    assert scheduler_stats is not None
    connector_stats = scheduler_stats.connector_prefix_cache_stats
    assert connector_stats is not None

    # Every request is counted once, all of its tokens as queries and the
    # connector-matched tokens as hits; nothing is counted as preempted.
    assert connector_stats.queries == tokens_per_req * num_reqs
    assert connector_stats.hits == matched_tokens_per_req * num_reqs
    assert connector_stats.requests == num_reqs
    assert connector_stats.preempted_requests == 0
def test_kv_connector_unable_to_allocate(): def test_kv_connector_unable_to_allocate():
""" """
Test whether scheduler with KVConnector is able to handle Test whether scheduler with KVConnector is able to handle

View File

@ -208,16 +208,11 @@ class KVCacheManager:
if self.log_stats: if self.log_stats:
assert self.prefix_cache_stats is not None assert self.prefix_cache_stats is not None
if request.num_preemptions > 0: self.prefix_cache_stats.record(
# Previously preempted request num_tokens=request.num_tokens,
self.prefix_cache_stats.preempted_requests += 1 num_hits=num_new_computed_tokens,
self.prefix_cache_stats.preempted_queries += request.num_tokens preempted=request.num_preemptions > 0,
self.prefix_cache_stats.preempted_hits += num_new_computed_tokens )
else:
# New request
self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_new_computed_tokens
return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens

View File

@ -28,7 +28,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
@ -84,6 +84,7 @@ class Scheduler(SchedulerInterface):
# will have a corresponding KVConnector with Role=WORKER. # will have a corresponding KVConnector with Role=WORKER.
# KV Connector pushes/pull of remote KVs for P/D and offloading. # KV Connector pushes/pull of remote KVs for P/D and offloading.
self.connector = None self.connector = None
self.connector_prefix_cache_stats: PrefixCacheStats | None = None
if self.vllm_config.kv_transfer_config is not None: if self.vllm_config.kv_transfer_config is not None:
assert len(self.kv_cache_config.kv_cache_groups) == 1, ( assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported " "Multiple KV cache groups are not currently supported "
@ -95,6 +96,8 @@ class Scheduler(SchedulerInterface):
self.connector = KVConnectorFactory.create_connector( self.connector = KVConnectorFactory.create_connector(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER config=self.vllm_config, role=KVConnectorRole.SCHEDULER
) )
if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats()
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config, self.kv_events_config,
@ -526,6 +529,9 @@ class Scheduler(SchedulerInterface):
new_computed_blocks + new_blocks, new_computed_blocks + new_blocks,
num_external_computed_tokens, num_external_computed_tokens,
) )
self._update_connector_prefix_cache_stats(
request, num_external_computed_tokens
)
# Request was already popped from self.waiting # Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None. # unless it was re-added above due to new_blocks being None.
@ -1247,11 +1253,13 @@ class Scheduler(SchedulerInterface):
return None return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
assert prefix_cache_stats is not None assert prefix_cache_stats is not None
connector_prefix_cache_stats = self._make_connector_prefix_cache_stats()
return SchedulerStats( return SchedulerStats(
num_running_reqs=len(self.running), num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting), num_waiting_reqs=len(self.waiting),
kv_cache_usage=self.kv_cache_manager.usage, kv_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=prefix_cache_stats, prefix_cache_stats=prefix_cache_stats,
connector_prefix_cache_stats=connector_prefix_cache_stats,
spec_decoding_stats=spec_decoding_stats, spec_decoding_stats=spec_decoding_stats,
num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running), num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running),
kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None,
@ -1282,6 +1290,25 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods # KV Connector Related Methods
######################################################################## ########################################################################
def _update_connector_prefix_cache_stats(
    self, request: Request, num_external_tokens: int
) -> None:
    """Record one request's connector lookup in the connector cache stats.

    Does nothing when ``self.connector_prefix_cache_stats`` is ``None``.

    Args:
        request: The request being scheduled; its ``num_tokens`` is recorded
            as the number of queried tokens, and ``num_preemptions > 0``
            marks it as a previously preempted request.
        num_external_tokens: Number of tokens recorded as hits.
    """
    stats = self.connector_prefix_cache_stats
    if stats is not None:
        stats.record(
            num_tokens=request.num_tokens,
            num_hits=num_external_tokens,
            preempted=request.num_preemptions > 0,
        )
def _make_connector_prefix_cache_stats(self) -> PrefixCacheStats | None:
    """Hand out the accumulated connector cache stats and start a new batch.

    Returns:
        The stats object accumulated so far, or ``None`` when
        ``self.connector_prefix_cache_stats`` is ``None``. When a stats
        object is returned, the attribute is replaced with a fresh
        ``PrefixCacheStats()`` so later records go to a new object.
    """
    collected = self.connector_prefix_cache_stats
    if collected is not None:
        # Swap in an empty accumulator; the caller keeps the old one.
        self.connector_prefix_cache_stats = PrefixCacheStats()
    return collected
def get_kv_connector(self) -> KVConnectorBase_V1 | None: def get_kv_connector(self) -> KVConnectorBase_V1 | None:
return self.connector return self.connector

View File

@ -93,6 +93,7 @@ class LoggingStatLogger(StatLoggerBase):
# Caching metrics. This cannot be reset. # Caching metrics. This cannot be reset.
# TODO: Make the interval configurable. # TODO: Make the interval configurable.
self.prefix_caching_metrics = CachingMetrics() self.prefix_caching_metrics = CachingMetrics()
self.connector_prefix_caching_metrics = CachingMetrics()
self.mm_caching_metrics = CachingMetrics() self.mm_caching_metrics = CachingMetrics()
self.spec_decoding_logging = SpecDecodingLogging() self.spec_decoding_logging = SpecDecodingLogging()
@ -140,6 +141,11 @@ class LoggingStatLogger(StatLoggerBase):
if scheduler_stats is not None: if scheduler_stats is not None:
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
if scheduler_stats.connector_prefix_cache_stats is not None:
self.connector_prefix_caching_metrics.observe(
scheduler_stats.connector_prefix_cache_stats
)
if scheduler_stats.spec_decoding_stats is not None: if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
if kv_connector_stats := scheduler_stats.kv_connector_stats: if kv_connector_stats := scheduler_stats.kv_connector_stats:
@ -192,6 +198,9 @@ class LoggingStatLogger(StatLoggerBase):
self.last_scheduler_stats.kv_cache_usage * 100, self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100, self.prefix_caching_metrics.hit_rate * 100,
] ]
if not self.connector_prefix_caching_metrics.empty:
log_parts.append("External prefix cache hit rate: %.1f%%")
log_args.append(self.connector_prefix_caching_metrics.hit_rate * 100)
if not self.mm_caching_metrics.empty: if not self.mm_caching_metrics.empty:
log_parts.append("MM cache hit rate: %.1f%%") log_parts.append("MM cache hit rate: %.1f%%")
log_args.append(self.mm_caching_metrics.hit_rate * 100) log_args.append(self.mm_caching_metrics.hit_rate * 100)
@ -457,6 +466,34 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
counter_prefix_cache_hits, engine_indexes, model_name counter_prefix_cache_hits, engine_indexes, model_name
) )
#
# External - KV connector prefix cache
#
counter_connector_prefix_cache_queries = self._counter_cls(
name="vllm:external_prefix_cache_queries",
documentation=(
"External prefix cache queries from KV connector "
"cross-instance cache sharing, in terms of number of queried tokens."
),
labelnames=labelnames,
)
self.counter_connector_prefix_cache_queries = make_per_engine(
counter_connector_prefix_cache_queries, engine_indexes, model_name
)
counter_connector_prefix_cache_hits = self._counter_cls(
name="vllm:external_prefix_cache_hits",
documentation=(
"External prefix cache hits from KV connector "
"cross-instance cache sharing, in terms of number of cached tokens."
),
labelnames=labelnames,
)
self.counter_connector_prefix_cache_hits = make_per_engine(
counter_connector_prefix_cache_hits, engine_indexes, model_name
)
# #
# Multi-modal cache # Multi-modal cache
# #
@ -883,6 +920,14 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
scheduler_stats.prefix_cache_stats.hits scheduler_stats.prefix_cache_stats.hits
) )
if scheduler_stats.connector_prefix_cache_stats is not None:
self.counter_connector_prefix_cache_queries[engine_idx].inc(
scheduler_stats.connector_prefix_cache_stats.queries
)
self.counter_connector_prefix_cache_hits[engine_idx].inc(
scheduler_stats.connector_prefix_cache_stats.hits
)
if scheduler_stats.spec_decoding_stats is not None: if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_prom.observe( self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats, engine_idx scheduler_stats.spec_decoding_stats, engine_idx

View File

@ -126,6 +126,19 @@ class PrefixCacheStats(BaseCacheStats):
preempted_hits: int = 0 preempted_hits: int = 0
"""The `hits` number for preempted requests.""" """The `hits` number for preempted requests."""
def record(self, num_tokens: int, num_hits: int, preempted: bool) -> None:
    """Fold a single request's lookup result into the counters.

    Args:
        num_tokens: Number of tokens added to the query counter.
        num_hits: Number of tokens added to the hit counter.
        preempted: When true, the ``preempted_*`` counters are updated
            instead of the regular ones.
    """
    if not preempted:
        # New request: update the regular counters.
        self.requests += 1
        self.queries += num_tokens
        self.hits += num_hits
        return

    # Previously preempted request: counted separately.
    self.preempted_requests += 1
    self.preempted_queries += num_tokens
    self.preempted_hits += num_hits
@dataclass @dataclass
class MultiModalCacheStats(BaseCacheStats): class MultiModalCacheStats(BaseCacheStats):
@ -151,6 +164,7 @@ class SchedulerStats:
kv_cache_usage: float = 0.0 kv_cache_usage: float = 0.0
prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats) prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
connector_prefix_cache_stats: PrefixCacheStats | None = None
spec_decoding_stats: SpecDecodingStats | None = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: dict[str, Any] | None = None kv_connector_stats: dict[str, Any] | None = None