[KV connector][LMCache] Only record the CUDA event when there are requests to store/load (#30814)

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Yihua Cheng 2025-12-17 21:31:34 -08:00 committed by GitHub
parent 82dc338ad6
commit ec965569d9
2 changed files with 40 additions and 17 deletions
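
The gist of the change: both the load path (`start_load_kv`) and the store path previously created and recorded an interprocess CUDA event on every step, even when no request had anything to retrieve or store. The diff moves the event creation behind an early return so the event is only recorded when something will actually consume it. A minimal consolidated sketch of the new control flow for the load path (the per-request filtering that the real code performs inside the loop is elided):

import torch

def start_load_kv_sketch(self, forward_context, **kwargs):
    metadata = self._get_connector_metadata()
    assert isinstance(metadata, LMCacheMPConnectorMetadata)

    # Collect only the requests that have a retrieve op this step.
    request_ids, ops = [], []
    for meta in metadata.requests:
        request_ids.append(meta.request_id)
        ops.append(meta.op)

    # Nothing to load: return before creating or recording any CUDA event.
    if len(request_ids) == 0:
        return

    # Record the interprocess event only when a worker will wait on it.
    with torch.cuda.stream(torch.cuda.current_stream()):
        event = torch.cuda.Event(interprocess=True)
        event.record()
    self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)

The store path (`batched_submit_store_requests`) gets the same restructuring further down.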


@@ -262,6 +262,7 @@ class LMCacheMPWorkerAdapter:
):
keys = []
block_ids = []
for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes))
block_ids.extend(op.block_ids)


@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker:
"""
self.num_stored_blocks += num_new_blocks
def update_block_ids(
def append_block_ids(
self,
new_block_ids: list[int],
):
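
The rename from `update_block_ids` to `append_block_ids` makes the accumulate-only semantics explicit. The method body is not shown in this hunk, so the following is only an assumed sketch of the intent (`allocated_block_ids` is a hypothetical attribute name):

def append_block_ids(self, new_block_ids: list[int]) -> None:
    # Assumed behaviour: block ids accumulate across scheduler steps and are
    # never replaced wholesale, hence "append" rather than "update".
    self.allocated_block_ids.extend(new_block_ids)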
@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = []
ops = []
@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids.append(meta.request_id)
ops.append(meta.op)
if len(request_ids) > 0:
self.worker_adapter.batched_submit_retrieve_requests(
request_ids, ops, event
)
if len(request_ids) == 0:
return
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = []
ops = []
for meta in metadata.requests:
@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids.append(meta.request_id)
ops.append(meta.op)
if len(request_ids) > 0:
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
if len(request_ids) == 0:
return
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
def get_finished(
self, finished_req_ids: set[str]
@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
into account.
"""
tracker = self._get_or_create_request_tracker(request)
# TODO: support loading KV for preempted requests in the future
if request.status == RequestStatus.PREEMPTED:
return 0, False
self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
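
In the v1 connector API, `get_num_new_matched_tokens` returns a pair of (number of tokens that can be loaded from the external cache, whether that load happens asynchronously), so the new guard simply reports "no external hit" for preempted requests and lets the scheduler recompute them. A condensed, hypothetical view of the guard in place:

from vllm.v1.request import RequestStatus

def get_num_new_matched_tokens_sketch(self, request, num_computed_tokens):
    # TODO from the diff: loading KV for preempted requests is not supported
    # yet, so report zero externally matched tokens and no async load.
    if request.status == RequestStatus.PREEMPTED:
        return 0, False
    # ... normal path: submit the LMCache lookup and report its result ...
    ...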
@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# Whether or not we need to retrieve, we need to update
# the block ids in the tracker
tracker.update_block_ids(block_ids)
tracker.append_block_ids(block_ids)
# Update the state of the tracker
condition = tracker.needs_retrieve()
@@ -866,7 +872,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# Update block ids
new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
request_tracker.update_block_ids(new_block_ids)
if request_id not in cached_reqs.resumed_req_ids:
request_tracker.append_block_ids(new_block_ids)
# Update new scheduled tokens
num_new_tokens = cached_reqs.num_computed_tokens[idx]
@@ -889,6 +896,21 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self, request: "Request"
) -> LMCacheMPRequestTracker:
request_id = request.request_id
# Remove the old tracker that was created before the preemption
if (
request.status == RequestStatus.PREEMPTED
and request_id in self.request_trackers
):
tracker = self.request_trackers[request_id]
# NOTE: since this function may be called multiple times for the
# same request (because get_num_new_matched_tokens may be called
# multiple times), we only remove the tracker if it is not in the
# "fresh" state, i.e., PREFETCHING
if tracker.state != LMCacheMPRequestState.PREFETCHING:
self.request_trackers.pop(request_id)
if request_id not in self.request_trackers:
new_tracker = LMCacheMPRequestTracker(request)
self.request_trackers[request_id] = new_tracker
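
For readability, here is the same get-or-create logic with its indentation restored and a guessed trailing return (the real function continues past the end of this hunk):

def _get_or_create_request_tracker_sketch(self, request):
    request_id = request.request_id
    # A preempted request comes back through this path; its pre-preemption
    # tracker no longer matches the blocks it will be re-allocated.
    if (
        request.status == RequestStatus.PREEMPTED
        and request_id in self.request_trackers
    ):
        tracker = self.request_trackers[request_id]
        # Only drop trackers that have progressed past the "fresh"
        # PREFETCHING state, so repeated calls for the same request
        # (get_num_new_matched_tokens may run more than once) stay idempotent.
        if tracker.state != LMCacheMPRequestState.PREFETCHING:
            self.request_trackers.pop(request_id)
    if request_id not in self.request_trackers:
        self.request_trackers[request_id] = LMCacheMPRequestTracker(request)
    return self.request_trackers[request_id]  # assumed; not shown in the hunk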