From ec965569d94c09eb1c85d235319b24d1b795d048 Mon Sep 17 00:00:00 2001
From: Yihua Cheng
Date: Wed, 17 Dec 2025 21:31:34 -0800
Subject: [PATCH] [KV connector][LMCache] Only record the cuda event when
 there are request to store/load (#30814)

Signed-off-by: ApostaC
---
 .../multi_process_adapter.py                  |  1 +
 .../kv_connector/v1/lmcache_mp_connector.py   | 56 +++++++++++++------
 2 files changed, 40 insertions(+), 17 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
index 6acfb73997f25..6656b5a25f83d 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@@ -262,6 +262,7 @@ class LMCacheMPWorkerAdapter:
     ):
         keys = []
         block_ids = []
+
         for op in ops:
             keys.extend(self._block_hashes_to_keys(op.block_hashes))
             block_ids.extend(op.block_ids)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
index 78256a6552c22..995708b89bc26 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
 )
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import KVConnectorOutput
+from vllm.v1.request import RequestStatus
 from vllm.v1.utils import ConstantList
 
 if TYPE_CHECKING:
@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker:
         """
         self.num_stored_blocks += num_new_blocks
 
-    def update_block_ids(
+    def append_block_ids(
         self,
         new_block_ids: list[int],
     ):
@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, LMCacheMPConnectorMetadata)
 
-        with torch.cuda.stream(torch.cuda.current_stream()):
-            event = torch.cuda.Event(interprocess=True)
-            event.record()
-
         request_ids = []
         ops = []
 
@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
             request_ids.append(meta.request_id)
             ops.append(meta.op)
 
-        if len(request_ids) > 0:
-            self.worker_adapter.batched_submit_retrieve_requests(
-                request_ids, ops, event
-            )
+        if len(request_ids) == 0:
+            return
+
+        with torch.cuda.stream(torch.cuda.current_stream()):
+            event = torch.cuda.Event(interprocess=True)
+            event.record()
+
+        self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)
 
     def wait_for_layer_load(self, layer_name: str) -> None:
         """
@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, LMCacheMPConnectorMetadata)
 
-        with torch.cuda.stream(torch.cuda.current_stream()):
-            event = torch.cuda.Event(interprocess=True)
-            event.record()
-
         request_ids = []
         ops = []
         for meta in metadata.requests:
@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
             request_ids.append(meta.request_id)
             ops.append(meta.op)
 
-        if len(request_ids) > 0:
-            self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
+        if len(request_ids) == 0:
+            return
+
+        with torch.cuda.stream(torch.cuda.current_stream()):
+            event = torch.cuda.Event(interprocess=True)
+            event.record()
+
+        self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
 
     def get_finished(
         self, finished_req_ids: set[str]
@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
         into account.
         """
         tracker = self._get_or_create_request_tracker(request)
+        # TODO: support loading KV for preempted requests in the future
+        if request.status == RequestStatus.PREEMPTED:
+            return 0, False
 
         self.scheduler_adapter.maybe_submit_lookup_request(
             request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
 
         # No matter we need to retrieve or not, we need to update
         # the block ids into the tracker
-        tracker.update_block_ids(block_ids)
+        tracker.append_block_ids(block_ids)
 
         # Update the state of the tracker
         condition = tracker.needs_retrieve()
@@ -866,7 +872,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
 
             # Update block ids
             new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
-            request_tracker.update_block_ids(new_block_ids)
+            if request_id not in cached_reqs.resumed_req_ids:
+                request_tracker.append_block_ids(new_block_ids)
 
             # Update new scheduled tokens
             num_new_tokens = cached_reqs.num_computed_tokens[idx]
@@ -889,6 +896,21 @@ class LMCacheMPConnector(KVConnectorBase_V1):
         self, request: "Request"
     ) -> LMCacheMPRequestTracker:
         request_id = request.request_id
+        # Remove the old trackers that is created before the preemption
+        if (
+            request.status == RequestStatus.PREEMPTED
+            and request_id in self.request_trackers
+        ):
+            tracker = self.request_trackers[request_id]
+
+            # NOTE: since this function may be called multiple times
+            # for a single request (because get_num_new_matched_tokens
+            # may be called multiple times) for the same request, we
+            # will only do the remove if the tracker is not in the "fresh"
+            # state, i.e., PREFETCHING
+            if tracker.state != LMCacheMPRequestState.PREFETCHING:
+                self.request_trackers.pop(request_id)
+
         if request_id not in self.request_trackers:
             new_tracker = LMCacheMPRequestTracker(request)
             self.request_trackers[request_id] = new_tracker
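
Below is a minimal standalone sketch of the pattern this patch applies in both the load and store paths: create and record the interprocess CUDA event only after confirming there is at least one pending op, so steps with nothing to store or load never touch the event at all. The PendingOp type and submit callback here are hypothetical stand-ins for the worker adapter's batched_submit_* calls, and the snippet assumes a CUDA-capable PyTorch build; it is illustrative, not taken from the patched files.

from collections.abc import Callable
from dataclasses import dataclass

import torch


@dataclass
class PendingOp:
    # Hypothetical stand-in for a per-request KV load/store operation.
    request_id: str
    block_ids: list[int]


def submit_with_event(ops: list[PendingOp], submit: Callable[..., None]) -> None:
    request_ids = [op.request_id for op in ops]
    if len(request_ids) == 0:
        # Nothing to store/load this step: return before touching CUDA,
        # so no event is created or recorded (the point of the patch).
        return

    # Record an interprocess event on the current stream so the other
    # process can wait until the KV-cache tensors are ready.
    with torch.cuda.stream(torch.cuda.current_stream()):
        event = torch.cuda.Event(interprocess=True)
        event.record()

    submit(request_ids, ops, event)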