From c6fa3895e90f6daef4d223188f6b4156311f40c9 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Fri, 21 Nov 2025 22:45:00 +0000 Subject: [PATCH] [KV Connector] Fix async connector prefix cache metrics (#28585) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mark McLoughlin Co-authored-by: Nicolò Lucchesi --- tests/v1/core/test_scheduler.py | 17 +++++++++++++---- vllm/v1/core/sched/scheduler.py | 16 ++++++++-------- vllm/v1/request.py | 3 +++ 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 04e738293cd77..d9a69a77c9797 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1057,7 +1057,8 @@ def test_kv_connector_basic(is_async: bool): ) -def test_external_prefix_cache_metrics(): +@pytest.mark.parametrize("is_async", [False, True]) +def test_external_prefix_cache_metrics(is_async: bool): """ Verify connector prefix cache metrics are updated correctly when the scheduler processes requests with KV connector hits. @@ -1067,7 +1068,9 @@ def test_external_prefix_cache_metrics(): NUM_MATCHED_NEW_TOKENS = 4 scheduler = create_scheduler( enable_prefix_caching=False, - use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False), + use_kv_connector=mock_kv( + matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async + ), ) # --- Prepare simple requests --- @@ -1079,9 +1082,15 @@ def test_external_prefix_cache_metrics(): num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS, ) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i - for req in requests: - scheduler.add_request(req) + if is_async: + _step_until_kv_transfer_finished(scheduler, req_ids) # --- Trigger scheduling and simulate model output --- output = scheduler.schedule() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9195b112d8690..4cb5348cbacc3 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -470,6 +470,7 @@ class Scheduler(SchedulerInterface): skipped_waiting_requests.prepend_request(request) continue + request.num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens # Total computed tokens (local + external). @@ -576,9 +577,6 @@ class Scheduler(SchedulerInterface): new_computed_blocks + new_blocks, num_external_computed_tokens, ) - self._update_connector_prefix_cache_stats( - request, num_external_computed_tokens - ) # Request was already popped from self.waiting # unless it was re-added above due to new_blocks being None. @@ -590,6 +588,8 @@ class Scheduler(SchedulerInterface): request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue + self._update_connector_prefix_cache_stats(request) + req_index += 1 self.running.append(request) if self.log_stats: @@ -1380,15 +1380,13 @@ class Scheduler(SchedulerInterface): # KV Connector Related Methods ######################################################################## - def _update_connector_prefix_cache_stats( - self, request: Request, num_external_tokens: int - ) -> None: + def _update_connector_prefix_cache_stats(self, request: Request) -> None: if self.connector_prefix_cache_stats is None: return self.connector_prefix_cache_stats.record( num_tokens=request.num_tokens, - num_hits=num_external_tokens, + num_hits=request.num_external_computed_tokens, preempted=request.num_preemptions > 0, ) @@ -1571,9 +1569,11 @@ class Scheduler(SchedulerInterface): marked_invalid_block = True # Truncate the computed tokens at the first failed block request.num_computed_tokens = idx * self.block_size - total_affected_tokens += ( + num_affected_tokens = ( req_num_computed_tokens - request.num_computed_tokens ) + total_affected_tokens += num_affected_tokens + request.num_external_computed_tokens -= num_affected_tokens if is_affected: if not marked_invalid_block: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 3d92906fbf4b1..366cdadf5a583 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -121,6 +121,9 @@ class Request: # The number of requests being preempted by the scheduler self.num_preemptions = 0 + # The number of tokens that have been computed remotely. + self.num_external_computed_tokens = 0 + self.block_hashes: list[BlockHash] = [] self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None if block_hasher is not None: