[KV connector][WIP] KV cache proxy based on LMCache multi-process mode (#27902)

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Yihua Cheng 2025-11-12 12:25:43 -08:00 committed by GitHub
parent a39dd7bb06
commit 94a9ebcf31
4 changed files with 1265 additions and 2 deletions
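The connector is selected through vLLM's KV transfer config, and the address of the LMCache multi-process server is read from the connector's extra config (see the LMCacheMPConnector docstring below). A minimal usage sketch, not part of this commit, assuming the standard KVTransferConfig fields and an LMCache multi-process server already running on the default address:

from vllm import LLM
from vllm.config import KVTransferConfig

# Hypothetical example: the model name is a placeholder, and the host/port
# values mirror the defaults the connector falls back to when
# lmcache.mp.host / lmcache.mp.port are not set in the extra config.
ktc = KVTransferConfig(
    kv_connector="LMCacheMPConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "lmcache.mp.host": "tcp://localhost",
        "lmcache.mp.port": 5555,
    },
)

# The connector does not work with the hybrid KV cache manager (see
# reformat_block_ids below); with the CLI this means passing
# --disable-hybrid-kv-cache-manager.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_transfer_config=ktc)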


@@ -161,6 +161,12 @@ KVConnectorFactory.register_connector(
"LMCacheConnectorV1",
)
KVConnectorFactory.register_connector(
"LMCacheMPConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_mp_connector",
"LMCacheMPConnector",
)
KVConnectorFactory.register_connector(
"NixlConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",


@@ -2,6 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from . import vllm_v1_adapter
+from . import multi_process_adapter, vllm_v1_adapter
+from .multi_process_adapter import (
+LMCacheMPSchedulerAdapter,
+LMCacheMPWorkerAdapter,
+LoadStoreOp,
+)
-__all__ = ["vllm_v1_adapter"]
+__all__ = [
+"vllm_v1_adapter",
+"multi_process_adapter",
+"LMCacheMPSchedulerAdapter",
+"LMCacheMPWorkerAdapter",
+"LoadStoreOp",
+]


@@ -0,0 +1,379 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import Any
import torch
import zmq
from lmcache.utils import _lmcache_nvtx_annotate, init_logger
from lmcache.v1.multiprocess.custom_types import (
CudaIPCWrapper,
IPCCacheEngineKey,
KVCache,
)
from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
logger = init_logger(__name__)
def wrap_kv_caches(kv_caches: dict[str, KVCache]) -> KVCache:
logger.info("KV caches keys are %s", list(kv_caches.keys()))
return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()]
def send_lmcache_request(
mq_client: MessageQueueClient,
request_type: RequestType,
payloads: list[Any],
) -> MessagingFuture[Any]:
future = mq_client.submit_request(
request_type, payloads, get_response_class(request_type)
)
return future
def get_lmcache_chunk_size(
mq_client: MessageQueueClient,
) -> int:
future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, [])
chunk_size = future.result()
return chunk_size
def striding_block_hashes(
block_hashes: list[bytes],
blocks_in_chunk: int,
) -> Iterable[bytes]:
"""Stride over the block hashes to get one block hash per chunk.
For example, if blocks_in_chunk is 16, we will get the block hashes
of the 16th, 32nd, 48th, ... blocks (i.e., the last block of each chunk).
"""
return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)
@dataclass
class LoadStoreOp:
block_hashes: list[bytes]
block_ids: list[int]
def __len__(self) -> int:
return len(self.block_hashes)
def __post_init__(self):
assert len(self.block_hashes) == len(self.block_ids), (
"The number of block hashes should be equal to the number of block ids, "
f"but got {len(self.block_hashes)} and {len(self.block_ids)}"
)
StoreResult = bool
RetrieveResult = list[bool]
LookupResult = list[bool]
class LMCacheMPSchedulerAdapter:
def __init__(
self,
server_url: str,
context: zmq.Context,
model_name: str,
world_size: int,
kv_rank: int,
vllm_block_size: int,
):
"""
Args:
server_url: The server URL for the LMCache message queue
context: The ZMQ context
model_name: The model name used for LMCache keys
world_size: The world size used for LMCache keys
kv_rank: The kv rank used for LMCache keys
vllm_block_size: The block size used in vLLM
"""
self.mq_client = MessageQueueClient(server_url, context)
# Request futures
self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}
self.model_name = model_name
self.world_size = world_size
self.worker_id = kv_rank
# Read chunk size from lmcache
self.chunk_size = get_lmcache_chunk_size(self.mq_client)
assert self.chunk_size % vllm_block_size == 0, (
"LMCache chunk size should be a multiple of vLLM block size"
)
self.blocks_in_chunk = self.chunk_size // vllm_block_size
@_lmcache_nvtx_annotate
def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]):
if request_id in self.lookup_futures:
# Skip if there is already a lookup request
return
s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
keys = [self._create_key(block_hash) for block_hash in s]
future = send_lmcache_request(
self.mq_client,
RequestType.LOOKUP,
[keys, True],
)
self.lookup_futures[request_id] = future
@_lmcache_nvtx_annotate
def check_lookup_result(self, request_id: str) -> int | None:
assert request_id in self.lookup_futures, (
f"Lookup request for request_id={request_id} has not been submitted"
)
future = self.lookup_futures[request_id]
if not future.query():
return None
result = future.result()
num_chunks = sum(result)
return num_chunks * self.chunk_size
def num_blocks_per_chunk(self) -> int:
"""
Returns:
The number of vllm blocks in a LMCache data chunk
"""
return self.blocks_in_chunk
# Helper functions
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
"""Convert a block hash to an IPC cache engine key"""
return IPCCacheEngineKey(
model_name=self.model_name,
world_size=self.world_size,
worker_id=self.worker_id,
chunk_hash=block_hash,
)
class LMCacheMPWorkerAdapter:
def __init__(
self,
server_url: str,
context: zmq.Context,
model_name: str,
world_size: int,
kv_rank: int,
vllm_block_size: int,
):
self.mq_client = MessageQueueClient(server_url, context)
# Instance id for GPU worker
self.instance_id = os.getpid()
# Registered kv caches from vLLM
self.kv_caches: dict[str, torch.Tensor] = {}
# Request futures
# request_id -> (future, other merged requests)
self.store_futures: dict[
str, tuple[MessagingFuture[StoreResult], list[str]]
] = {}
self.retrieve_futures: dict[
str, tuple[MessagingFuture[RetrieveResult], list[str]]
] = {}
self.finished_stores: set[str] = set()
self.previously_finished: set[str] = set()
self.model_name = model_name
self.world_size = world_size
self.worker_id = kv_rank
# Read chunk size from lmcache
chunk_size = get_lmcache_chunk_size(self.mq_client)
assert chunk_size % vllm_block_size == 0, (
"LMCache chunk size should be a multiple of vLLM block size"
)
self.blocks_in_chunk = chunk_size // vllm_block_size
def register_kv_caches(self, kv_caches: dict[str, KVCache]):
# Register kv cache and send the request
self.kv_caches = kv_caches
logger.info("Registering kv caches")
future = send_lmcache_request(
self.mq_client,
RequestType.REGISTER_KV_CACHE,
[self.instance_id, wrap_kv_caches(kv_caches)],
)
future.result()
@_lmcache_nvtx_annotate
def submit_store_request(
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
keys = self._block_hashes_to_keys(op.block_hashes)
future = send_lmcache_request(
self.mq_client,
RequestType.STORE,
[keys, self.instance_id, op.block_ids, event.ipc_handle()],
).to_cuda_future()
self.store_futures[request_id] = (future, [])
@_lmcache_nvtx_annotate
def submit_retrieve_request(
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
keys = self._block_hashes_to_keys(op.block_hashes)
future = send_lmcache_request(
self.mq_client,
RequestType.RETRIEVE,
[keys, self.instance_id, op.block_ids, event.ipc_handle()],
).to_cuda_future()
self.retrieve_futures[request_id] = (future, [])
@_lmcache_nvtx_annotate
def batched_submit_store_requests(
self,
request_ids: list[str],
ops: list[LoadStoreOp],
event: torch.cuda.Event,
):
keys = []
block_ids = []
for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes))
block_ids.extend(op.block_ids)
future = send_lmcache_request(
self.mq_client,
RequestType.STORE,
[keys, self.instance_id, block_ids, event.ipc_handle()],
).to_cuda_future()
self.store_futures[request_ids[0]] = (future, request_ids[1:])
@_lmcache_nvtx_annotate
def batched_submit_retrieve_requests(
self,
request_ids: list[str],
ops: list[LoadStoreOp],
event: torch.cuda.Event,
):
keys = []
block_ids = []
for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes))
block_ids.extend(op.block_ids)
future = send_lmcache_request(
self.mq_client,
RequestType.RETRIEVE,
[keys, self.instance_id, block_ids, event.ipc_handle()],
).to_cuda_future()
self.retrieve_futures[request_ids[0]] = (future, request_ids[1:])
@_lmcache_nvtx_annotate
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
finished_stores = set()
finished_retrieves = set()
for request_id, (future, other_reqs) in self.store_futures.items():
if not future.query():
continue
result = future.result()
finished_stores.add(request_id)
finished_stores.update(other_reqs)
if not result:
# TODO: add error handling here
logger.error(
"Something went wrong when processing the "
"store request for request_id=%s",
request_id,
)
for request_id, (future, other_reqs) in self.retrieve_futures.items():
if not future.query():
continue
result = future.result()
finished_retrieves.add(request_id)
finished_retrieves.update(other_reqs)
if not all(result):
# TODO: add error handling here
logger.error(
"Something went wrong when processing the "
"retrieve request for request_id=%s, result=%s",
request_id,
result,
)
logger.info("Retrieve request for request_id=%s finished", request_id)
# Remove the finished requests from the tracking dicts
for request_id in finished_stores:
self.store_futures.pop(request_id, None)
for request_id in finished_retrieves:
self.retrieve_futures.pop(request_id, None)
# Update the internal states
self.finished_stores.update(finished_stores)
ret_stores = set()
for req_id in finished_req_ids:
if req_id in self.finished_stores or req_id in self.store_futures:
self.previously_finished.add(req_id)
else:
ret_stores.add(req_id)
# Calculate the final finished stores
ret_stores.update(self._update_and_get_finished_store())
return ret_stores, finished_retrieves
def num_blocks_per_chunk(self) -> int:
"""
Returns:
The number of vllm blocks in a LMCache data chunk
"""
return self.blocks_in_chunk
def shutdown(self):
# Unregister kv cache
logger.info("Unregistering kv caches")
send_lmcache_request(
self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
).result()
self.mq_client.close()
# Helper functions
def _update_and_get_finished_store(
self,
) -> set[str]:
"""Converge the internal states about finished stores
and returns the 'safe finished store request ids' back
"""
safe_finished_s = self.finished_stores.intersection(self.previously_finished)
self.finished_stores.difference_update(self.previously_finished)
self.previously_finished.difference_update(safe_finished_s)
return safe_finished_s
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
"""Convert a block hash to an IPC cache engine key"""
return IPCCacheEngineKey(
model_name=self.model_name,
world_size=self.world_size,
worker_id=self.worker_id,
chunk_hash=block_hash,
)
def _block_hashes_to_keys(
self, block_hashes: list[bytes]
) -> list[IPCCacheEngineKey]:
"""Convert block hashes to IPC cache engine keys"""
s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
return [self._create_key(block_hash) for block_hash in s]
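To make the block-hash striding above concrete, here is a small standalone sketch, not part of the commit, that replays the striding_block_hashes arithmetic with made-up numbers (16 blocks per LMCache chunk):

from itertools import islice

block_hashes = [f"hash-{i}".encode() for i in range(48)]  # 48 full vLLM blocks
blocks_in_chunk = 16  # e.g. LMCache chunk_size 256 tokens / vllm_block_size 16

# Same slice as striding_block_hashes: keep the hash of the last block of
# every chunk, i.e. indices 15, 31, 47 (the 16th, 32nd and 48th blocks).
per_chunk_hashes = list(
    islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)
)
assert per_chunk_hashes == [block_hashes[15], block_hashes[31], block_hashes[47]]

# The worker adapter builds one IPCCacheEngineKey per chunk from these
# strided hashes, while a LoadStoreOp still carries every individual block
# id, so storing 48 blocks sends 3 keys and 48 block ids to the server.
print(len(per_chunk_hashes))  # 3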


@@ -0,0 +1,867 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
import zmq
from lmcache.utils import init_logger as lmcache_init_logger
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
LMCacheMPSchedulerAdapter,
LMCacheMPWorkerAdapter,
LoadStoreOp,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = lmcache_init_logger(__name__)
# Helper functions
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
if block_ids is None:
return []
assert isinstance(block_ids, tuple), (
f"Expected block_ids to be a tuple of lists, but got {type(block_ids)}"
)
if len(block_ids) > 1:
raise RuntimeError(
"LMCacheMPConnector only works without hybrid kv cache manager. "
"Please pass --disable-hybrid-kv-cache-manager when starting vllm"
)
return block_ids[0]
def create_scheduler_adapter(
server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPSchedulerAdapter:
# TODO: have a helper function to calculate the correct rank and
# world size for the MLA and other models
return LMCacheMPSchedulerAdapter(
server_url,
zmq_context,
vllm_config.model_config.model,
vllm_config.parallel_config.world_size,
vllm_config.parallel_config.rank,
vllm_config.cache_config.block_size,
)
def create_worker_adapter(
server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPWorkerAdapter:
# TODO: have a helper function to calculate the correct rank and
# world size for the MLA and other models
return LMCacheMPWorkerAdapter(
server_url,
zmq_context,
vllm_config.model_config.model,
vllm_config.parallel_config.world_size,
vllm_config.parallel_config.rank,
vllm_config.cache_config.block_size,
)
def convert_block_hashes_to_bytes(
block_hashes: list["BlockHash"],
) -> list[bytes]:
return cast(list[bytes], block_hashes)
class LMCacheMPRequestState(enum.Enum):
"""
State machine:
PREFETCHING -- update_state_after_alloc --> WAITING_FOR_LOAD (retrieve needed)
PREFETCHING -- update_state_after_alloc --> READY (no retrieve needed)
WAITING_FOR_LOAD -- _process_retrieve_requests --> READY
"""
PREFETCHING = enum.auto()
WAITING_FOR_LOAD = enum.auto()
READY = enum.auto()
@dataclass
class LMCacheMPRequestTracker:
# NOTE: this class uses vLLM data structures, so it should be part of
# the vLLM integration code
request_id: str
# Read-only lists to track the token ids and block hashes
all_token_ids: ConstantList[int]
block_hashes: ConstantList["BlockHash"]
# Block ids and hashes will be updated at update_states_after_alloc and
# during the generation
allocated_block_ids: list[int] = field(default_factory=list)
# Number of scheduled tokens in this request. We keep tracking this to
# avoid saving half-full blocks.
num_scheduled_tokens: int = 0
# The number of stored blocks is initialized when looking up the external
# hit tokens and is updated when processing new requests and cached
# requests.
num_stored_blocks: int = 0
# Staging load operation -- record the vLLM and LMCache hit block counts during lookup
num_vllm_hit_blocks: int = 0
num_lmcache_hit_blocks: int = 0
# Main state
state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING
def __init__(self, request: "Request"):
self.request_id = request.request_id
self.all_token_ids = request.all_token_ids
self.block_hashes = ConstantList(request.block_hashes)
self.allocated_block_ids = []
self.num_stored_blocks = 0
self.num_vllm_hit_blocks = 0
self.num_lmcache_hit_blocks = 0
self.state = LMCacheMPRequestState.PREFETCHING
####
# Check the state of the request
####
def needs_retrieve(self) -> bool:
"""Check whether the current request needs retrieve, will be used
update_stage_after_alloc"""
return (
self.num_lmcache_hit_blocks > self.num_vllm_hit_blocks
and self.state != LMCacheMPRequestState.READY
)
def is_ready_for_retrieving(self) -> bool:
"""Check whether the current request is ready for retrieving,
will be used in process_loading_requests"""
return (
self.state == LMCacheMPRequestState.WAITING_FOR_LOAD
and self.needs_retrieve()
)
####
# Update internal states
####
def increase_num_scheduled_tokens(self, num_new_tokens: int):
self.num_scheduled_tokens += num_new_tokens
def increase_num_stored_blocks(self, num_new_blocks: int):
"""Increase the number of stored blocks for the current request
This function will be called when processing the cached requests.
"""
self.num_stored_blocks += num_new_blocks
def update_block_ids(
self,
new_block_ids: list[int],
):
"""Update the block ids for the current request
This function will be called when processing the cached requests.
"""
self.allocated_block_ids.extend(new_block_ids)
####
# For debugging
####
def __repr__(self) -> str:
return (
f"LMCacheMPRequestTracker(request_id={self.request_id}, "
f"num_tokens={len(self.all_token_ids)}, "
f"num_block_hashes={len(self.block_hashes)}, "
f"num_allocated_blocks={len(self.allocated_block_ids)}, "
f"num_stored_blocks={self.num_stored_blocks}, "
f"vllm_hit_blocks={self.num_vllm_hit_blocks}, "
f"lmcache_hit_blocks={self.num_lmcache_hit_blocks}, "
f"state={self.state})"
)
def __str__(self) -> str:
return self.__repr__()
@dataclass
class LMCacheMPRequestMetadata:
request_id: str
direction: Literal["STORE", "RETRIEVE"]
op: LoadStoreOp
@staticmethod
def GetStoreMetadata(
tracker: LMCacheMPRequestTracker,
blocks_in_chunk: int,
vllm_block_size: int,
) -> "LMCacheMPRequestMetadata | None":
"""
Generate the store metadata for the current request tracker.
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
"""
# Store the blocks that have block hashes
# NOTE: the invariant here is that `num_stored_blocks` should
# always be a multiple of `blocks_in_chunk`
# TODO: This should be checked every time we update num_stored_blocks
min_available_blocks = min(
len(tracker.block_hashes),
len(tracker.allocated_block_ids),
tracker.num_scheduled_tokens // vllm_block_size,
)
num_staging_blocks = min_available_blocks - tracker.num_stored_blocks
num_chunks = num_staging_blocks // blocks_in_chunk
if num_chunks >= 1:
start = tracker.num_stored_blocks
end = start + num_chunks * blocks_in_chunk
block_hashes = convert_block_hashes_to_bytes(
tracker.block_hashes[start:end]
)
block_ids = tracker.allocated_block_ids[start:end]
ret = LMCacheMPRequestMetadata(
request_id=tracker.request_id,
direction="STORE",
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
)
# Update the request tracker
tracker.increase_num_stored_blocks(end - start)
return ret
return None
@staticmethod
def GetRetrieveMetadata(
tracker: LMCacheMPRequestTracker,
blocks_in_chunk: int,
) -> "LMCacheMPRequestMetadata | None":
"""
Generate the retrieve metadata for the current request tracker.
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
"""
if not tracker.is_ready_for_retrieving():
return None
# |---------------------|-----------------|----------------|
# | num_vllm_hit_blocks |
# | lmcache chunk 1 | lmcache chunk 2 |
# | need to retrieve |
start = tracker.num_vllm_hit_blocks // blocks_in_chunk * blocks_in_chunk
end = tracker.num_lmcache_hit_blocks
assert end % blocks_in_chunk == 0, (
"The number of LMCache hit blocks should be a multiple of the "
"number of blocks in a lmcache chunk. "
)
assert len(tracker.block_hashes) >= end, (
"The number of block hashes should be greater than or equal to the "
"number of LMCache hit blocks. "
)
if end > start:
block_hashes = convert_block_hashes_to_bytes(
tracker.block_hashes[start:end]
)
block_ids = tracker.allocated_block_ids[start:end]
ret = LMCacheMPRequestMetadata(
request_id=tracker.request_id,
direction="RETRIEVE",
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
)
return ret
return None
class LMCacheMPConnectorMetadata(KVConnectorMetadata):
def __init__(self):
super().__init__()
self.requests: list[LMCacheMPRequestMetadata] = []
def add_request_metadata(self, request_metadata: LMCacheMPRequestMetadata):
self.requests.append(request_metadata)
def __len__(self):
return len(self.requests)
# For debugging
def __str__(self):
request_strs = []
for req_meta in self.requests:
request_strs.append(
f"RequestMetadata(request_id={req_meta.request_id}, "
f"direction={req_meta.direction}, "
f"num_blocks={len(req_meta.op)}, "
f"block_ids={req_meta.op.block_ids})"
)
return "[" + "\n".join(request_strs) + "]"
def __repr__(self):
return self.__str__()
class LMCacheMPConnector(KVConnectorBase_V1):
"""
The connector for LMCache multi-process mode.
Extra configs (kv_transfer_config.extra_config):
- lmcache.mp.host: the host of the LMCache server.
- lmcache.mp.port: the port of the LMCache server.
"""
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
server_host = vllm_config.kv_transfer_config.get_from_extra_config(
"lmcache.mp.host", "tcp://localhost"
)
server_port = vllm_config.kv_transfer_config.get_from_extra_config(
"lmcache.mp.port", 5555
)
server_url = f"{server_host}:{server_port}"
zmq_context = zmq.Context.instance()
if self.role == KVConnectorRole.SCHEDULER:
self.scheduler_adapter = create_scheduler_adapter(
server_url, zmq_context, vllm_config
)
self.request_trackers: dict[str, LMCacheMPRequestTracker] = {}
elif self.role == KVConnectorRole.WORKER:
self.worker_adapter = create_worker_adapter(
server_url, zmq_context, vllm_config
)
else:
raise ValueError(f"Unknown KVConnectorRole: {self.role}")
self.vllm_block_size = vllm_config.cache_config.block_size
@property
def role(self) -> KVConnectorRole:
return self._role
# ==============================
# Worker-side methods
# ==============================
def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.
This function should only be called inside the connector.
Returns:
ConnectorMetadata: the connector metadata.
"""
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args:
kv_caches: dictionary of layer names, kv cache
"""
logger.info("Registering kv caches!")
self.worker_adapter.register_kv_caches(kv_caches)
return
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = []
ops = []
for meta in metadata.requests:
if meta.direction != "RETRIEVE":
continue
request_ids.append(meta.request_id)
ops.append(meta.op)
if len(request_ids) > 0:
logger.debug(
"Submitting batched retrieve requests for request_ids=%s", request_ids
)
self.worker_adapter.batched_submit_retrieve_requests(
request_ids, ops, event
)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
return
def wait_for_save(self):
"""
Block until all the save operations are done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of the paged KV buffer before saving is done.
"""
metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = []
ops = []
for meta in metadata.requests:
if meta.direction != "STORE":
continue
request_ids.append(meta.request_id)
ops.append(meta.op)
if len(request_ids) > 0:
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens on the worker.
The scheduler process (via the Executors) will use this output
to track which workers are done.
Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
val = self.worker_adapter.get_finished(finished_req_ids)
# logger.error("Finished req ids: %s, %s", val[0], val[1])
return val
def get_block_ids_with_load_errors(self) -> set[int]:
"""
Get the set of block IDs that failed to load.
Returns:
Set of block IDs that encountered load errors.
Empty set if no load errors occurred.
Notes:
- Applies to both sync- and async-loading requests.
- Async loading: failed blocks may be reported in any forward pass
up to and including the pass where the request ID is returned by
`get_finished()`. Even if failures occur, the request must still
be reported via `get_finished()`, and the failed block IDs must
appear here no later than that same pass.
- Sync loading: failed blocks should be reported in the forward
pass in which they are detected.
"""
# TODO: add error tracking
return set()
def shutdown(self):
"""
Shutdown the connector. This is called when the worker process
is shutting down to ensure that all the async operations are
completed and the connector is cleaned up properly.
"""
if hasattr(self, "worker_adapter"):
self.worker_adapter.shutdown()
return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
Get the KV connector stats collected during the last interval.
"""
return None
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
Notes:
The connector should only consider the largest prefix of prompt-
tokens for which KV cache is actually available at the time of the
call. If the cache cannot be loaded for some tokens (e.g., due to
connectivity issues or eviction), those tokens must not be taken
into account.
"""
tracker = self._get_or_create_request_tracker(request)
self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
)
ret = self.scheduler_adapter.check_lookup_result(request.request_id)
if ret is None:
return None, True
if ret == 0:
return 0, False
assert (
ret % (self.scheduler_adapter.num_blocks_per_chunk() * self.vllm_block_size)
== 0
)
# Update num stored blocks for the tracker
num_vllm_blocks = num_computed_tokens // self.vllm_block_size
num_lmcache_blocks = ret // self.vllm_block_size
tracker.increase_num_stored_blocks(num_lmcache_blocks)
# Save the vllm and lmcache hit tokens
tracker.num_vllm_hit_blocks = num_vllm_blocks
tracker.num_lmcache_hit_blocks = num_lmcache_blocks
need_to_load = max(0, ret - num_computed_tokens)
logger.debug(
"vLLM hit is: %d, Need to load is %d", num_computed_tokens, need_to_load
)
return need_to_load, need_to_load > 0
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If get_num_new_matched_tokens previously returned True for a
request, this function may be called twice for that same request -
first when blocks are allocated for the connector tokens to be
asynchronously loaded into, and second when any additional blocks
are allocated, after the load/transfer is complete.
Args:
request (Request): the request object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
"""
# NOTE: the `blocks` are NEW BLOCKS allocated for this request.
tracker = self._get_request_tracker(request.request_id)
block_ids = reformat_block_ids(blocks.get_block_ids())
# Whether or not we need to retrieve, we need to record
# the block ids in the tracker
tracker.update_block_ids(block_ids)
# Update the state of the tracker
condition = tracker.needs_retrieve()
if tracker.state == LMCacheMPRequestState.PREFETCHING:
# If need to retrieve, change to WAITING_FOR_LOAD
# Otherwise, change to READY
tracker.state = (
LMCacheMPRequestState.WAITING_FOR_LOAD
if condition
else LMCacheMPRequestState.READY
)
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
metadata = LMCacheMPConnectorMetadata()
self._process_retrieve_requests(metadata)
self._process_new_requests(scheduler_output, metadata)
self._process_cached_requests(scheduler_output, metadata)
if len(metadata) > 0:
logger.debug("Final connector metadata: %s", metadata)
return metadata
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
return
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assume responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return True, None
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return ()
@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if cls is KVConnectorBase_V1:
raise TypeError(
"get_required_kvcache_layout should not be called "
"on the abstract base class"
)
return None
def get_finished_count(self) -> int | None:
"""
Get the count of requests expected to complete send/receive operations
via this connector. This method is used to initialize the
KVOutputAggregator, overwriting the default world_size.
Returns:
int: expected sending or receiving completion count.
"""
return None
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return None
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> Optional["KVConnectorPromMetrics"]:
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
expose connector transfer stats via Prometheus.
"""
return None
##############################
# Helper functions
##############################
def _process_retrieve_requests(
self,
metadata: LMCacheMPConnectorMetadata,
) -> None:
blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()
for request_tracker in self.request_trackers.values():
if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD:
continue
r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata(
request_tracker, blocks_per_chunk
)
if r_metadata is not None:
metadata.add_request_metadata(r_metadata)
request_tracker.state = LMCacheMPRequestState.READY
def _process_new_requests(
self,
scheduler_output: SchedulerOutput,
metadata: LMCacheMPConnectorMetadata,
) -> None:
blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()
for new_request in scheduler_output.scheduled_new_reqs:
request_tracker = self._get_request_tracker(new_request.req_id)
num_new_tokens = scheduler_output.num_scheduled_tokens[new_request.req_id]
request_tracker.increase_num_scheduled_tokens(num_new_tokens)
r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
request_tracker, blocks_per_chunk, self.vllm_block_size
)
if r_meta is not None:
metadata.add_request_metadata(r_meta)
def _process_cached_requests(
self,
scheduler_output: SchedulerOutput,
metadata: LMCacheMPConnectorMetadata,
) -> None:
blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()
cached_reqs = scheduler_output.scheduled_cached_reqs
for idx, request_id in enumerate(cached_reqs.req_ids):
request_tracker = self._get_request_tracker(request_id)
# Update block ids
new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
request_tracker.update_block_ids(new_block_ids)
# Update new scheduled tokens
num_new_tokens = cached_reqs.num_computed_tokens[idx]
request_tracker.increase_num_scheduled_tokens(num_new_tokens)
r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
request_tracker, blocks_per_chunk, self.vllm_block_size
)
if r_meta is not None:
metadata.add_request_metadata(r_meta)
def _get_request_tracker(self, request_id: str) -> LMCacheMPRequestTracker:
assert request_id in self.request_trackers, (
f"Request tracker for request_id {request_id} not found. "
)
return self.request_trackers[request_id]
def _get_or_create_request_tracker(
self, request: "Request"
) -> LMCacheMPRequestTracker:
request_id = request.request_id
if request_id not in self.request_trackers:
new_tracker = LMCacheMPRequestTracker(request)
self.request_trackers[request_id] = new_tracker
return self.request_trackers[request_id]
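As a closing illustration of the chunk-aligned windows that GetStoreMetadata and GetRetrieveMetadata compute, here is a standalone sketch, not part of the commit, that replays the same arithmetic with made-up numbers:

# Illustrative assumptions: 16 vLLM blocks per LMCache chunk, 16 tokens
# per vLLM block.
blocks_in_chunk = 16
vllm_block_size = 16

# Store side (GetStoreMetadata): only whole chunks beyond what has already
# been stored are written out.
num_block_hashes = 70
num_allocated_blocks = 70
num_scheduled_tokens = 1100
num_stored_blocks = 32  # two chunks already stored

min_available_blocks = min(
    num_block_hashes,
    num_allocated_blocks,
    num_scheduled_tokens // vllm_block_size,  # 68 fully scheduled blocks
)
num_staging_blocks = min_available_blocks - num_stored_blocks  # 36
num_chunks = num_staging_blocks // blocks_in_chunk  # 2 full chunks
store_start = num_stored_blocks  # 32
store_end = store_start + num_chunks * blocks_in_chunk  # 64
print(store_start, store_end)  # -> 32 64

# Retrieve side (GetRetrieveMetadata): start at the chunk boundary at or
# below the vLLM prefix-cache hit, end at the (chunk-aligned) LMCache hit.
num_vllm_hit_blocks = 20
num_lmcache_hit_blocks = 48
retrieve_start = num_vllm_hit_blocks // blocks_in_chunk * blocks_in_chunk  # 16
retrieve_end = num_lmcache_hit_blocks  # 48
print(retrieve_start, retrieve_end)  # -> 16 48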