Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-21 08:04:29 +08:00)
[KV connector][LMCache] Only record the cuda event when there are requests to store/load (#30814)
Signed-off-by: ApostaC <yihua98@uchicago.edu>
This commit is contained in:
parent: 82dc338ad6
commit: ec965569d9
@@ -262,6 +262,7 @@ class LMCacheMPWorkerAdapter:
    ):
        keys = []
        block_ids = []

        for op in ops:
            keys.extend(self._block_hashes_to_keys(op.block_hashes))
            block_ids.extend(op.block_ids)
@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
 )
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import KVConnectorOutput
+from vllm.v1.request import RequestStatus
 from vllm.v1.utils import ConstantList

 if TYPE_CHECKING:
@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker:
         """
         self.num_stored_blocks += num_new_blocks

-    def update_block_ids(
+    def append_block_ids(
         self,
         new_block_ids: list[int],
     ):
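The rename from update_block_ids to append_block_ids makes the accumulate-only semantics explicit. A rough sketch of the intended contract; the class below is an illustrative stand-in, not the actual LMCacheMPRequestTracker, whose other fields are not shown in this diff:

class BlockIdTracker:
    """Illustrative stand-in for a per-request block-id tracker."""

    def __init__(self) -> None:
        self.block_ids: list[int] = []
        self.num_stored_blocks = 0

    def append_block_ids(self, new_block_ids: list[int]) -> None:
        # Accumulate newly allocated block ids; callers are expected to pass
        # only the blocks added since the previous scheduler step.
        self.block_ids.extend(new_block_ids)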
@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, LMCacheMPConnectorMetadata)

-        with torch.cuda.stream(torch.cuda.current_stream()):
-            event = torch.cuda.Event(interprocess=True)
-            event.record()
-
         request_ids = []
         ops = []

@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
             request_ids.append(meta.request_id)
             ops.append(meta.op)

-        if len(request_ids) > 0:
-            self.worker_adapter.batched_submit_retrieve_requests(
-                request_ids, ops, event
-            )
+        if len(request_ids) == 0:
+            return
+
+        with torch.cuda.stream(torch.cuda.current_stream()):
+            event = torch.cuda.Event(interprocess=True)
+            event.record()
+
+        self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)

     def wait_for_layer_load(self, layer_name: str) -> None:
         """
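Together with the hunk above, this is the heart of the change: the interprocess CUDA event used to be created and recorded on every call, even when no retrieve ops were scheduled; now the method returns early and records the event only when there is work to hand to the worker. A minimal sketch of the resulting pattern, with names other than torch being illustrative:

import torch


def record_submit_event(request_ids: list[str]) -> torch.cuda.Event | None:
    """Record an interprocess CUDA event only when there is work to submit."""
    if len(request_ids) == 0:
        # Nothing to load/store this step, so skip the event entirely.
        return None

    # Record on the current stream so the event captures all GPU work
    # issued so far for this batch.
    with torch.cuda.stream(torch.cuda.current_stream()):
        event = torch.cuda.Event(interprocess=True)
        event.record()
    return event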
@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, LMCacheMPConnectorMetadata)

-        with torch.cuda.stream(torch.cuda.current_stream()):
-            event = torch.cuda.Event(interprocess=True)
-            event.record()
-
         request_ids = []
         ops = []
         for meta in metadata.requests:
@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
             request_ids.append(meta.request_id)
             ops.append(meta.op)

-        if len(request_ids) > 0:
-            self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
+        if len(request_ids) == 0:
+            return
+
+        with torch.cuda.stream(torch.cuda.current_stream()):
+            event = torch.cuda.Event(interprocess=True)
+            event.record()
+
+        self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)

     def get_finished(
         self, finished_req_ids: set[str]
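The store path gets the same treatment. The recorded event is handed to the worker adapter; the adapter side is not shown in this diff, but PyTorch interprocess events are typically shared across processes via an IPC handle, roughly as sketched below. This is illustrative only and not necessarily LMCache's actual handoff:

import torch

# In the process that recorded the event: export an opaque handle that can be
# sent to another process (e.g. over a pipe or message queue).
event = torch.cuda.Event(interprocess=True)
event.record()
handle = event.ipc_handle()

# In the receiving process: rebuild the event from the handle and make the
# local stream wait for the recorded GPU work before touching the KV buffers.
peer_event = torch.cuda.Event.from_ipc_handle(torch.cuda.current_device(), handle)
torch.cuda.current_stream().wait_event(peer_event)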
@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
         into account.
         """
         tracker = self._get_or_create_request_tracker(request)
+        # TODO: support loading KV for preempted requests in the future
+        if request.status == RequestStatus.PREEMPTED:
+            return 0, False

         self.scheduler_adapter.maybe_submit_lookup_request(
             request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):

         # No matter we need to retrieve or not, we need to update
         # the block ids into the tracker
-        tracker.update_block_ids(block_ids)
+        tracker.append_block_ids(block_ids)

         # Update the state of the tracker
         condition = tracker.needs_retrieve()
@@ -866,7 +872,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):

             # Update block ids
             new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
-            request_tracker.update_block_ids(new_block_ids)
+            if request_id not in cached_reqs.resumed_req_ids:
+                request_tracker.append_block_ids(new_block_ids)

             # Update new scheduled tokens
             num_new_tokens = cached_reqs.num_computed_tokens[idx]
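On the scheduler side, block ids are now appended only for requests that are not being resumed from preemption, presumably because a resumed request's tracker is rebuilt (see the next hunk) and already accounts for the reported blocks, so appending them again would duplicate entries. A simplified guard with hypothetical names:

def maybe_append_block_ids(
    tracker_block_ids: list[int],
    new_block_ids: list[int],
    request_id: str,
    resumed_req_ids: set[str],
) -> None:
    # Hypothetical helper mirroring the guard in the diff: only requests that
    # are not being resumed from preemption get their new block ids appended.
    if request_id not in resumed_req_ids:
        tracker_block_ids.extend(new_block_ids)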
@@ -889,6 +896,21 @@ class LMCacheMPConnector(KVConnectorBase_V1):
         self, request: "Request"
     ) -> LMCacheMPRequestTracker:
         request_id = request.request_id
+        # Remove the old trackers that is created before the preemption
+        if (
+            request.status == RequestStatus.PREEMPTED
+            and request_id in self.request_trackers
+        ):
+            tracker = self.request_trackers[request_id]
+
+            # NOTE: since this function may be called multiple times
+            # for a single request (because get_num_new_matched_tokens
+            # may be called multiple times) for the same request, we
+            # will only do the remove if the tracker is not in the "fresh"
+            # state, i.e., PREFETCHING
+            if tracker.state != LMCacheMPRequestState.PREFETCHING:
+                self.request_trackers.pop(request_id)
+
         if request_id not in self.request_trackers:
             new_tracker = LMCacheMPRequestTracker(request)
             self.request_trackers[request_id] = new_tracker
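Because get_num_new_matched_tokens (and therefore this helper) can run more than once for the same request, the cleanup only drops trackers that have moved past the fresh PREFETCHING state, which keeps the get-or-create idempotent. A condensed, self-contained sketch of that logic; the enum values and tracker fields here are stand-ins, not the real LMCache types:

from dataclasses import dataclass
from enum import Enum, auto


class TrackerState(Enum):
    PREFETCHING = auto()  # freshly created, safe to reuse
    RETRIEVING = auto()
    STORING = auto()


@dataclass
class Tracker:
    request_id: str
    state: TrackerState = TrackerState.PREFETCHING


def get_or_create_tracker(
    trackers: dict[str, Tracker], request_id: str, preempted: bool
) -> Tracker:
    # Drop a stale tracker left over from before the preemption, but only if
    # it has already progressed past the fresh state; a second call for the
    # same request then reuses the tracker created by the first call.
    if preempted and request_id in trackers:
        if trackers[request_id].state is not TrackerState.PREFETCHING:
            trackers.pop(request_id)

    if request_id not in trackers:
        trackers[request_id] = Tracker(request_id)
    return trackers[request_id]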