diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 494a4d3c33aa..df871dd7cbe4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -161,6 +161,12 @@ KVConnectorFactory.register_connector( "LMCacheConnectorV1", ) +KVConnectorFactory.register_connector( + "LMCacheMPConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_mp_connector", + "LMCacheMPConnector", +) + KVConnectorFactory.register_connector( "NixlConnector", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py index 3c73a1c09e58..07e05cc8f893 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py @@ -2,6 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from . import vllm_v1_adapter +from . import multi_process_adapter, vllm_v1_adapter +from .multi_process_adapter import ( + LMCacheMPSchedulerAdapter, + LMCacheMPWorkerAdapter, + LoadStoreOp, +) -__all__ = ["vllm_v1_adapter"] +__all__ = [ + "vllm_v1_adapter", + "multi_process_adapter", + "LMCacheMPSchedulerAdapter", + "LMCacheMPWorkerAdapter", + "LoadStoreOp", +] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py new file mode 100644 index 000000000000..ab2eeed9f6b8 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py @@ -0,0 +1,379 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from collections.abc import Iterable +from dataclasses import dataclass +from itertools import islice +from typing import Any + +import torch +import zmq +from lmcache.utils import _lmcache_nvtx_annotate, init_logger +from lmcache.v1.multiprocess.custom_types import ( + CudaIPCWrapper, + IPCCacheEngineKey, + KVCache, +) +from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture +from lmcache.v1.multiprocess.protocol import RequestType, get_response_class + +logger = init_logger(__name__) + + +def wrap_kv_caches(kv_caches: dict[str, KVCache]) -> KVCache: + logger.info("KV caches keys are %s", list(kv_caches.keys())) + return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()] + + +def send_lmcache_request( + mq_client: MessageQueueClient, + request_type: RequestType, + payloads: list[Any], +) -> MessagingFuture[Any]: + future = mq_client.submit_request( + request_type, payloads, get_response_class(request_type) + ) + return future + + +def get_lmcache_chunk_size( + mq_client: MessageQueueClient, +) -> int: + future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, []) + chunk_size = future.result() + return chunk_size + + +def striding_block_hashes( + block_hashes: list[bytes], + blocks_in_chunk, +) -> Iterable[bytes]: + """Striding the block hashes to get the block hashes for each chunk. + For example, if blocks_in_chunk is 16, then we will get the block hashes + for the 16th, 32nd, 48th, ... blocks. 
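+    Equivalently, this yields the last block hash of each complete chunk and
+    drops any trailing partial chunk. For instance, with blocks_in_chunk=2 and
+    block_hashes=[h0, h1, h2, h3, h4], the iterator yields h1 and h3, and the
+    partial chunk containing h4 is skipped.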
+ """ + return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk) + + +@dataclass +class LoadStoreOp: + block_hashes: list[bytes] + block_ids: list[int] + + def __len__(self) -> int: + return len(self.block_hashes) + + def __post_init__(self): + assert len(self.block_hashes) == len(self.block_ids), ( + "The number of block hashes should be equal to the number of block ids " + f"But got {len(self.block_hashes)} and {len(self.block_ids)}" + ) + + +StoreResult = bool +RetrieveResult = list[bool] +LookupResult = list[bool] + + +class LMCacheMPSchedulerAdapter: + def __init__( + self, + server_url: str, + context: zmq.Context, + model_name: str, + world_size: int, + kv_rank: int, + vllm_block_size: int, + ): + """ + Args: + server_url: The server URL for the LMCache message queue + context: The ZMQ context + + model_name: The model name used for LMCache keys + world_size: The world size used for LMCache keys + kv_rank: The kv rank used for LMCache keys + vllm_block_size: The block size used in vLLM + """ + self.mq_client = MessageQueueClient(server_url, context) + + # Request futures + self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {} + + self.model_name = model_name + self.world_size = world_size + self.worker_id = kv_rank + + # Read chunk size from lmcache + self.chunk_size = get_lmcache_chunk_size(self.mq_client) + assert self.chunk_size % vllm_block_size == 0, ( + "LMCache chunk size should be a multiple of vLLM block size" + ) + self.blocks_in_chunk = self.chunk_size // vllm_block_size + + @_lmcache_nvtx_annotate + def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]): + if request_id in self.lookup_futures: + # Skip if there is already a lookup request + return + + s = striding_block_hashes(block_hashes, self.blocks_in_chunk) + keys = [self._create_key(block_hash) for block_hash in s] + future = send_lmcache_request( + self.mq_client, + RequestType.LOOKUP, + [keys, True], + ) + self.lookup_futures[request_id] = future + + @_lmcache_nvtx_annotate + def check_lookup_result(self, request_id: str) -> int | None: + assert request_id in self.lookup_futures, ( + f"Lookup request for request_id={request_id} has not been submitted" + ) + + future = self.lookup_futures[request_id] + if not future.query(): + return None + + result = future.result() + num_chunks = sum(result) + return num_chunks * self.chunk_size + + def num_blocks_per_chunk(self) -> int: + """ + Returns: + The number of vllm blocks in a LMCache data chunk + """ + return self.blocks_in_chunk + + # Helper functions + def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey: + """Convert a block hash to an IPC cache engine key""" + return IPCCacheEngineKey( + model_name=self.model_name, + world_size=self.world_size, + worker_id=self.worker_id, + chunk_hash=block_hash, + ) + + +class LMCacheMPWorkerAdapter: + def __init__( + self, + server_url: str, + context: zmq.Context, + model_name: str, + world_size: int, + kv_rank: int, + vllm_block_size: int, + ): + self.mq_client = MessageQueueClient(server_url, context) + + # Instance id for GPU worker + self.instance_id = os.getpid() + + # Registered kv caches from vLLM + self.kv_caches: dict[str, torch.Tensor] = {} + + # Request futures + # request_id -> (future, other merged requests) + self.store_futures: dict[ + str, tuple[MessagingFuture[StoreResult], list[str]] + ] = {} + self.retrieve_futures: dict[ + str, tuple[MessagingFuture[RetrieveResult], list[str]] + ] = {} + + self.finished_stores: set[str] = set() + 
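+        # Request ids that vLLM has already reported as finished, but whose
+        # store requests were still being tracked at that point; they are only
+        # released from get_finished() once the corresponding store completes
+        # (see _update_and_get_finished_store()).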
self.previously_finished: set[str] = set() + + self.model_name = model_name + self.world_size = world_size + self.worker_id = kv_rank + + # Read chunk size from lmcache + chunk_size = get_lmcache_chunk_size(self.mq_client) + assert chunk_size % vllm_block_size == 0, ( + "LMCache chunk size should be a multiple of vLLM block size" + ) + self.blocks_in_chunk = chunk_size // vllm_block_size + + def register_kv_caches(self, kv_caches: dict[str, KVCache]): + # Register kv cache and send the request + self.kv_caches = kv_caches + logger.info("Registering kv caches") + future = send_lmcache_request( + self.mq_client, + RequestType.REGISTER_KV_CACHE, + [self.instance_id, wrap_kv_caches(kv_caches)], + ) + future.result() + + @_lmcache_nvtx_annotate + def submit_store_request( + self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event + ): + keys = self._block_hashes_to_keys(op.block_hashes) + future = send_lmcache_request( + self.mq_client, + RequestType.STORE, + [keys, self.instance_id, op.block_ids, event.ipc_handle()], + ).to_cuda_future() + self.store_futures[request_id] = (future, []) + + @_lmcache_nvtx_annotate + def submit_retrieve_request( + self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event + ): + keys = self._block_hashes_to_keys(op.block_hashes) + future = send_lmcache_request( + self.mq_client, + RequestType.RETRIEVE, + [keys, self.instance_id, op.block_ids, event.ipc_handle()], + ).to_cuda_future() + self.retrieve_futures[request_id] = (future, []) + + @_lmcache_nvtx_annotate + def batched_submit_store_requests( + self, + request_ids: list[str], + ops: list[LoadStoreOp], + event: torch.cuda.Event, + ): + keys = [] + block_ids = [] + for op in ops: + keys.extend(self._block_hashes_to_keys(op.block_hashes)) + block_ids.extend(op.block_ids) + future = send_lmcache_request( + self.mq_client, + RequestType.STORE, + [keys, self.instance_id, block_ids, event.ipc_handle()], + ).to_cuda_future() + self.store_futures[request_ids[0]] = (future, request_ids[1:]) + + @_lmcache_nvtx_annotate + def batched_submit_retrieve_requests( + self, + request_ids: list[str], + ops: list[LoadStoreOp], + event: torch.cuda.Event, + ): + keys = [] + block_ids = [] + for op in ops: + keys.extend(self._block_hashes_to_keys(op.block_hashes)) + block_ids.extend(op.block_ids) + future = send_lmcache_request( + self.mq_client, + RequestType.RETRIEVE, + [keys, self.instance_id, block_ids, event.ipc_handle()], + ).to_cuda_future() + self.retrieve_futures[request_ids[0]] = (future, request_ids[1:]) + + @_lmcache_nvtx_annotate + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + finished_stores = set() + finished_retrieves = set() + for request_id, (future, other_reqs) in self.store_futures.items(): + if not future.query(): + continue + + result = future.result() + finished_stores.add(request_id) + finished_stores.update(other_reqs) + + if not result: + # TODO: add error handling here + logger.error( + "Something went wrong when processing the " + "store request for request_id=%s", + request_id, + ) + + for request_id, (future, other_reqs) in self.retrieve_futures.items(): + if not future.query(): + continue + + result = future.result() + finished_retrieves.add(request_id) + finished_retrieves.update(other_reqs) + + if not all(result): + # TODO: add error handing here + logger.error( + "Something went wrong when processing the " + "retrieve request for request_id=%s, result=%s", + request_id, + result, + ) + logger.info("Retrieve request for 
request_id=%s finished", request_id) + + # Remove the finished requests from the tracking dicts + for request_id in finished_stores: + self.store_futures.pop(request_id, None) + for request_id in finished_retrieves: + self.retrieve_futures.pop(request_id, None) + + # Update the internal states + self.finished_stores.update(finished_stores) + + ret_stores = set() + for req_id in finished_req_ids: + if req_id in self.finished_stores or req_id in self.store_futures: + self.previously_finished.add(req_id) + else: + ret_stores.add(req_id) + + # Calculate the final finished stores + ret_stores.update(self._update_and_get_finished_store()) + + return ret_stores, finished_retrieves + + def num_blocks_per_chunk(self) -> int: + """ + Returns: + The number of vllm blocks in a LMCache data chunk + """ + return self.blocks_in_chunk + + def shutdown(self): + # Unregister kv cache + logger.info("Unregistering kv caches") + send_lmcache_request( + self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id] + ).result() + + self.mq_client.close() + + # Helper functions + def _update_and_get_finished_store( + self, + ) -> set[str]: + """Converge the internal states about finished stores + and returns the 'safe finished store request ids' back + """ + safe_finished_s = self.finished_stores.intersection(self.previously_finished) + self.finished_stores.difference_update(self.previously_finished) + self.previously_finished.difference_update(safe_finished_s) + + return safe_finished_s + + def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey: + """Convert a block hash to an IPC cache engine key""" + return IPCCacheEngineKey( + model_name=self.model_name, + world_size=self.world_size, + worker_id=self.worker_id, + chunk_hash=block_hash, + ) + + def _block_hashes_to_keys( + self, block_hashes: list[bytes] + ) -> list[IPCCacheEngineKey]: + """Convert block hashes to IPC cache engine keys""" + s = striding_block_hashes(block_hashes, self.blocks_in_chunk) + return [self._create_key(block_hash) for block_hash in s] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py new file mode 100644 index 000000000000..55831dc56c80 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py @@ -0,0 +1,867 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, Optional, cast + +import torch +import zmq +from lmcache.utils import init_logger as lmcache_init_logger + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import ( + LMCacheMPSchedulerAdapter, + LMCacheMPWorkerAdapter, + LoadStoreOp, +) +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import KVConnectorOutput +from vllm.v1.utils import ConstantList + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.distributed.kv_events import KVCacheEvent + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorPromMetrics, + KVConnectorStats, + PromMetric, + PromMetricT, + ) + from vllm.forward_context import ForwardContext + from 
vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.core.kv_cache_utils import BlockHash + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +logger = lmcache_init_logger(__name__) + + +# Helper functions +def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]: + if block_ids is None: + return [] + assert isinstance(block_ids, tuple), ( + f"Expected block_ids to be a tuple of lists, but got {type(block_ids)}" + ) + + if len(block_ids) > 1: + raise RuntimeError( + "LMCacheMPConnector only works without hybrid kv cache manager. " + "Please pass --disable-hybrid-kv-cache-manager when starting vllm" + ) + + return block_ids[0] + + +def create_scheduler_adapter( + server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig +) -> LMCacheMPSchedulerAdapter: + # TODO: have a helper function to calculate the correct rank and + # world size for the MLA and other models + return LMCacheMPSchedulerAdapter( + server_url, + zmq_context, + vllm_config.model_config.model, + vllm_config.parallel_config.world_size, + vllm_config.parallel_config.rank, + vllm_config.cache_config.block_size, + ) + + +def create_worker_adapter( + server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig +) -> LMCacheMPWorkerAdapter: + # TODO: have a helper function to calculate the correct rank and + # world size for the MLA and other models + return LMCacheMPWorkerAdapter( + server_url, + zmq_context, + vllm_config.model_config.model, + vllm_config.parallel_config.world_size, + vllm_config.parallel_config.rank, + vllm_config.cache_config.block_size, + ) + + +def convert_block_hashes_to_bytes( + block_hashes: list["BlockHash"], +) -> list[bytes]: + return cast(list[bytes], block_hashes) + + +class LMCacheMPRequestState(enum.Enum): + """ + State machine: + PREFETCHING -- update_state_after_alloc --> WAITING_FOR_LOAD + WAITING_FOR_LOAD -- process_loading_requests --> READY + """ + + PREFETCHING = enum.auto() + WAITING_FOR_LOAD = enum.auto() + READY = enum.auto() + + +@dataclass +class LMCacheMPRequestTracker: + # NOTE: this class used vLLM data structures, should be part of + # vLLM integration code + + request_id: str + + # Read-only lists to track the token ids and block hashes + all_token_ids: ConstantList[int] + block_hashes: ConstantList["BlockHash"] + + # Block ids and hashes will be updated at update_states_after_alloc and + # during the generation + allocated_block_ids: list[int] = field(default_factory=list) + + # Number of scheduled tokens in this request. We keep tracking this to + # avoid saving half-full blocks. + num_scheduled_tokens: int = 0 + + # Number of blocks stored will be initialized when lookup the external + # hit tokens and will be updated when processing new requests and cached + # requests. 
+ num_stored_blocks: int = 0 + + # Staging load operation -- save vllm and lmcache hit tokens during lookup + num_vllm_hit_blocks: int = 0 + num_lmcache_hit_blocks: int = 0 + + # Main state + state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING + + def __init__(self, request: "Request"): + self.request_id = request.request_id + self.all_token_ids = request.all_token_ids + self.block_hashes = ConstantList(request.block_hashes) + self.allocated_block_ids = [] + self.num_stored_blocks = 0 + self.num_vllm_hit_blocks = 0 + self.num_lmcache_hit_blocks = 0 + self.state = LMCacheMPRequestState.PREFETCHING + + #### + # Check the state of the request + #### + def needs_retrieve(self) -> bool: + """Check whether the current request needs retrieve, will be used + update_stage_after_alloc""" + return ( + self.num_lmcache_hit_blocks > self.num_vllm_hit_blocks + and self.state != LMCacheMPRequestState.READY + ) + + def is_ready_for_retrieving(self) -> bool: + """Check whether the current request is ready for retrieving, + will be used in process_loading_requests""" + return ( + self.state == LMCacheMPRequestState.WAITING_FOR_LOAD + and self.needs_retrieve() + ) + + #### + # Update internal states + #### + def increase_num_scheduled_tokens(self, num_new_tokens: int): + self.num_scheduled_tokens += num_new_tokens + + def increase_num_stored_blocks(self, num_new_blocks: int): + """Increase the number of stored blocks for the current request + This function will be called when processing the cached requests. + """ + self.num_stored_blocks += num_new_blocks + + def update_block_ids( + self, + new_block_ids: list[int], + ): + """Update the block ids for the current request + This function will be called when processing the cached requests. + """ + self.allocated_block_ids.extend(new_block_ids) + + #### + # For debugging + #### + def __repr__(self) -> str: + return ( + f"LMCacheMPRequestTracker(request_id={self.request_id}, " + f"num_tokens={len(self.all_token_ids)}, " + f"num_block_hashes={len(self.block_hashes)}, " + f"num_allocated_blocks={len(self.allocated_block_ids)}, " + f"num_stored_blocks={self.num_stored_blocks}, " + f"vllm_hit_blocks={self.num_vllm_hit_blocks}, " + f"lmcache_hit_blocks={self.num_lmcache_hit_blocks}, " + f"state={self.state})" + ) + + def __str__(self) -> str: + return self.__repr__() + + +@dataclass +class LMCacheMPRequestMetadata: + request_id: str + direction: Literal["STORE", "RETRIEVE"] + op: LoadStoreOp + + @staticmethod + def GetStoreMetadata( + tracker: LMCacheMPRequestTracker, + blocks_in_chunk: int, + vllm_block_size: int, + ) -> "LMCacheMPRequestMetadata | None": + """ + Generate the store metadata for the current request tracker. + + Args: + tracker: The request tracker to generate the metadata from. 
+ blocks_in_chunk: the number of blocks in a LMCache data chunk + """ + # Store the blocks that has block hashes + # NOTE: the invariant here is that `num_stored_blocks` should + # always be a multiple of `blocks_in_chunk` + # TODO: This should be checked everytime we update the num_stored_blocks + min_available_blocks = min( + len(tracker.block_hashes), + len(tracker.allocated_block_ids), + tracker.num_scheduled_tokens // vllm_block_size, + ) + num_staging_blocks = min_available_blocks - tracker.num_stored_blocks + num_chunks = num_staging_blocks // blocks_in_chunk + + if num_chunks >= 1: + start = tracker.num_stored_blocks + end = start + num_chunks * blocks_in_chunk + block_hashes = convert_block_hashes_to_bytes( + tracker.block_hashes[start:end] + ) + block_ids = tracker.allocated_block_ids[start:end] + + ret = LMCacheMPRequestMetadata( + request_id=tracker.request_id, + direction="STORE", + op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids), + ) + + # Update the request tracker + tracker.increase_num_stored_blocks(end - start) + return ret + + return None + + @staticmethod + def GetRetrieveMetadata( + tracker: LMCacheMPRequestTracker, + blocks_in_chunk: int, + ) -> "LMCacheMPRequestMetadata | None": + """ + Generate the retrieve metadata for the current request tracker. + + Args: + tracker: The request tracker to generate the metadata from. + blocks_in_chunk: the number of blocks in a LMCache data chunk + """ + if not tracker.is_ready_for_retrieving(): + return None + + # |---------------------|-----------------|----------------| + # | num_vllm_hit_blocks | + # | lmcache chunk 1 | lmcache chunk 2 | + # | need to retrieve | + + start = tracker.num_vllm_hit_blocks // blocks_in_chunk * blocks_in_chunk + end = tracker.num_lmcache_hit_blocks + assert end % blocks_in_chunk == 0, ( + "The number of LMCache hit blocks should be a multiple of the " + "number of blocks in a lmcache chunk. " + ) + assert len(tracker.block_hashes) >= end, ( + "The number of block hashes should be greater than or equal to the " + "number of LMCache hit blocks. " + ) + if end > start: + block_hashes = convert_block_hashes_to_bytes( + tracker.block_hashes[start:end] + ) + block_ids = tracker.allocated_block_ids[start:end] + + ret = LMCacheMPRequestMetadata( + request_id=tracker.request_id, + direction="RETRIEVE", + op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids), + ) + return ret + + return None + + +class LMCacheMPConnectorMetadata(KVConnectorMetadata): + def __init__(self): + super().__init__() + self.requests: list[LMCacheMPRequestMetadata] = [] + + def add_request_metadata(self, request_metadata: LMCacheMPRequestMetadata): + self.requests.append(request_metadata) + + def __len__(self): + return len(self.requests) + + # For debugging + def __str__(self): + request_strs = [] + for req_meta in self.requests: + request_strs.append( + f"RequestMetadata(request_id={req_meta.request_id}, " + f"direction={req_meta.direction}, " + f"num_blocks={len(req_meta.op)}, " + f"block_ids={req_meta.op.block_ids})" + ) + return "[" + "\n".join(request_strs) + "]" + + def __repr__(self): + return self.__str__() + + +class LMCacheMPConnector(KVConnectorBase_V1): + """ + The connector for LMCache multi-process mode. + + Extra configs (kv_transfer_config.extra_config): + - lmcache.mp.host: the host of the LMCache server. + - lmcache.mp.port: the port of the LMCache server. 
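+
+    Example (illustrative; assumes vLLM's standard CLI flags and
+    KVTransferConfig fields, and an LMCache multi-process server already
+    listening on tcp://localhost:5555):
+
+        vllm serve <model> --disable-hybrid-kv-cache-manager \
+            --kv-transfer-config '{"kv_connector": "LMCacheMPConnector",
+                "kv_role": "kv_both",
+                "kv_connector_extra_config": {"lmcache.mp.host": "tcp://localhost",
+                                              "lmcache.mp.port": 5555}}'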
+ """ + + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + + assert vllm_config.kv_transfer_config is not None + server_host = vllm_config.kv_transfer_config.get_from_extra_config( + "lmcache.mp.host", "tcp://localhost" + ) + server_port = vllm_config.kv_transfer_config.get_from_extra_config( + "lmcache.mp.port", 5555 + ) + + server_url = f"{server_host}:{server_port}" + zmq_context = zmq.Context.instance() + if self.role == KVConnectorRole.SCHEDULER: + self.scheduler_adapter = create_scheduler_adapter( + server_url, zmq_context, vllm_config + ) + self.request_trackers: dict[str, LMCacheMPRequestTracker] = {} + elif self.role == KVConnectorRole.WORKER: + self.worker_adapter = create_worker_adapter( + server_url, zmq_context, vllm_config + ) + else: + raise ValueError(f"Unknown KVConnectorRole: {self.role}") + + self.vllm_block_size = vllm_config.cache_config.block_size + + @property + def role(self) -> KVConnectorRole: + return self._role + + # ============================== + # Worker-side methods + # ============================== + + def _get_connector_metadata(self) -> KVConnectorMetadata: + """Get the connector metadata. + + This function should only be called inside the connector. + + Returns: + ConnectorMetadata: the connector metadata. + """ + + # Should only be called while set to valid metadata. + assert self._connector_metadata is not None + return self._connector_metadata + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). + + Args: + kv_caches: dictionary of layer names, kv cache + """ + logger.info("Registering kv caches!") + self.worker_adapter.register_kv_caches(kv_caches) + return + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + metadata = self._get_connector_metadata() + assert isinstance(metadata, LMCacheMPConnectorMetadata) + + with torch.cuda.stream(torch.cuda.current_stream()): + event = torch.cuda.Event(interprocess=True) + event.record() + + request_ids = [] + ops = [] + + for meta in metadata.requests: + if meta.direction != "RETRIEVE": + continue + request_ids.append(meta.request_id) + ops.append(meta.op) + + if len(request_ids) > 0: + logger.info( + "HERE! SUBMITTING THE BATCHED RETRIEVE REQUESTS %s", request_ids + ) + self.worker_adapter.batched_submit_retrieve_requests( + request_ids, ops, event + ) + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. 
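+        In this connector, retrieves are submitted per request (batched) in
+        start_load_kv() and complete asynchronously; completion is reported
+        through get_finished(), so there is nothing to wait for per layer.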
+ + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: + """ + Start saving a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + return + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + metadata = self._get_connector_metadata() + assert isinstance(metadata, LMCacheMPConnectorMetadata) + + with torch.cuda.stream(torch.cuda.current_stream()): + event = torch.cuda.Event(interprocess=True) + event.record() + + request_ids = [] + ops = [] + for meta in metadata.requests: + if meta.direction != "STORE": + continue + request_ids.append(meta.request_id) + ops.append(meta.op) + + if len(request_ids) > 0: + self.worker_adapter.batched_submit_store_requests(request_ids, ops, event) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens on the worker. + The scheduler process (via the Executors) will use this output + to track which workers are done. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + val = self.worker_adapter.get_finished(finished_req_ids) + # logger.error("Finished req ids: %s, %s", val[0], val[1]) + return val + + def get_block_ids_with_load_errors(self) -> set[int]: + """ + Get the set of block IDs that failed to load. + + Returns: + Set of block IDs that encountered load errors. + Empty set if no load errors occurred. + + Notes: + - Applies to both sync- and async-loading requests. + - Async loading: failed blocks may be reported in any forward pass + up to and including the pass where the request ID is returned by + `get_finished()`. Even if failures occur, the request must still + be reported via `get_finished()`, and the failed block IDs must + appear here no later than that same pass. + - Sync loading: failed blocks should be reported in the forward + pass in which they are detected. + """ + # TODO: add error tracking + return set() + + def shutdown(self): + """ + Shutdown the connector. This is called when the worker process + is shutting down to ensure that all the async operations are + completed and the connector is cleaned up properly. + """ + if hasattr(self, "worker_adapter"): + self.worker_adapter.shutdown() + return None + + def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]: + """ + Get the KV connector stats collected during the last interval. 
+ """ + return None + + # ============================== + # Scheduler-side methods + # ============================== + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + A tuple with the following elements: + - An optional number of tokens that can be loaded from the + external KV cache beyond what is already computed. + If None, it means that the connector needs more time to + determine the number of matched tokens, and the scheduler + should query for this request again later. + - `True` if external KV cache tokens will be loaded + asynchronously (between scheduler steps). Must be + 'False' if the first element is 0. + + Notes: + The connector should only consider the largest prefix of prompt- + tokens for which KV cache is actually available at the time of the + call. If the cache cannot be loaded for some tokens (e.g., due to + connectivity issues or eviction), those tokens must not be taken + into account. + """ + tracker = self._get_or_create_request_tracker(request) + + self.scheduler_adapter.maybe_submit_lookup_request( + request.request_id, convert_block_hashes_to_bytes(request.block_hashes) + ) + + ret = self.scheduler_adapter.check_lookup_result(request.request_id) + if ret is None: + return None, True + + if ret == 0: + return 0, False + + assert ( + ret % (self.scheduler_adapter.num_blocks_per_chunk() * self.vllm_block_size) + == 0 + ) + + # Update num stored blocks for the tracker + num_vllm_blocks = num_computed_tokens // self.vllm_block_size + num_lmcache_blocks = ret // self.vllm_block_size + tracker.increase_num_stored_blocks(num_lmcache_blocks) + + # Save the vllm and lmcache hit tokens + tracker.num_vllm_hit_blocks = num_vllm_blocks + tracker.num_lmcache_hit_blocks = num_lmcache_blocks + + need_to_load = max(0, ret - num_computed_tokens) + logger.debug( + "vLLM hit is: %d, Need to load is %d", num_computed_tokens, need_to_load + ) + return need_to_load, need_to_load > 0 + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + """ + Update KVConnector state after block allocation. + + If get_num_new_matched_tokens previously returned True for a + request, this function may be called twice for that same request - + first when blocks are allocated for the connector tokens to be + asynchronously loaded into, and second when any additional blocks + are allocated, after the load/transfer is complete. + + Args: + request (Request): the request object. + blocks (KVCacheBlocks): the blocks allocated for the request. + num_external_tokens (int): the number of tokens that will be + loaded from the external KV cache. + """ + # NOTE: the `blocks` are NEW BLOCKS allocated for this request. 
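+        # `blocks.get_block_ids()` returns one block-id list per KV cache
+        # group; reformat_block_ids() rejects multiple groups (the hybrid KV
+        # cache manager must be disabled) and unwraps the single list.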
+ tracker = self._get_request_tracker(request.request_id) + block_ids = reformat_block_ids(blocks.get_block_ids()) + + # No matter we need to retrieve or not, we need to update + # the block ids into the tracker + tracker.update_block_ids(block_ids) + + # Update the state of the tracker + condition = tracker.needs_retrieve() + if tracker.state == LMCacheMPRequestState.PREFETCHING: + # If need to retrieve, change to WAITING_FOR_LOAD + # Otherwise, change to READY + tracker.state = ( + LMCacheMPRequestState.WAITING_FOR_LOAD + if condition + else LMCacheMPRequestState.READY + ) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + metadata = LMCacheMPConnectorMetadata() + + self._process_retrieve_requests(metadata) + self._process_new_requests(scheduler_output, metadata) + self._process_cached_requests(scheduler_output, metadata) + + if len(metadata) > 0: + logger.debug("Final connector metadata: %s", metadata) + + return metadata + + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + return + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Called exactly once when a request has finished, before its blocks are + freed. + + The connector may assumes responsibility for freeing the blocks + asynchronously by returning True. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return True, None + + def take_events(self) -> Iterable["KVCacheEvent"]: + """ + Take the KV cache events from the connector. + + Yields: + New KV cache events since the last call. + """ + return () + + @classmethod + def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: + """ + Get the required KV cache layout for this connector. + Args: + vllm_config (VllmConfig): the vllm config. + + Returns: + str: the required KV cache layout. e.g. HND, or NHD. + None if the connector does not require a specific layout. + """ + + if cls is KVConnectorBase_V1: + raise TypeError( + "get_required_kvcache_layout should not be called " + "on the abstract base class" + ) + return None + + def get_finished_count(self) -> int | None: + """ + Get the count of requests expected to complete send/receive operations + via this connector. This method is used to initialize the + KVOutputAggregator, overwriting the default world_size. + + Returns: + int: expected sending or receiving completion count. + """ + return None + + @classmethod + def build_kv_connector_stats( + cls, data: dict[str, Any] | None = None + ) -> Optional["KVConnectorStats"]: + """ + KVConnectorStats resolution method. This method allows dynamically + registered connectors to return their own KVConnectorStats object, + which can implement custom aggregation logic on the data dict. 
+ """ + return None + + @classmethod + def build_prom_metrics( + cls, + vllm_config: "VllmConfig", + metric_types: dict[type["PromMetric"], type["PromMetricT"]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ) -> Optional["KVConnectorPromMetrics"]: + """ + Create a KVConnectorPromMetrics subclass which should register + per-connector Prometheus metrics and implement observe() to + expose connector transfer stats via Prometheus. + """ + return None + + ############################## + # Helper functions + ############################## + def _process_retrieve_requests( + self, + metadata: LMCacheMPConnectorMetadata, + ) -> None: + blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk() + + for request_tracker in self.request_trackers.values(): + if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD: + continue + r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata( + request_tracker, blocks_per_chunk + ) + if r_metadata is not None: + metadata.add_request_metadata(r_metadata) + request_tracker.state = LMCacheMPRequestState.READY + + def _process_new_requests( + self, + scheduler_output: SchedulerOutput, + metadata: LMCacheMPConnectorMetadata, + ) -> None: + blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk() + + for new_request in scheduler_output.scheduled_new_reqs: + request_tracker = self._get_request_tracker(new_request.req_id) + + num_new_tokens = scheduler_output.num_scheduled_tokens[new_request.req_id] + request_tracker.increase_num_scheduled_tokens(num_new_tokens) + + r_meta = LMCacheMPRequestMetadata.GetStoreMetadata( + request_tracker, blocks_per_chunk, self.vllm_block_size + ) + if r_meta is not None: + metadata.add_request_metadata(r_meta) + + def _process_cached_requests( + self, + scheduler_output: SchedulerOutput, + metadata: LMCacheMPConnectorMetadata, + ) -> None: + blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk() + + cached_reqs = scheduler_output.scheduled_cached_reqs + for idx, request_id in enumerate(cached_reqs.req_ids): + request_tracker = self._get_request_tracker(request_id) + + # Update block ids + new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx]) + request_tracker.update_block_ids(new_block_ids) + + # Update new scheduled tokens + num_new_tokens = cached_reqs.num_computed_tokens[idx] + request_tracker.increase_num_scheduled_tokens(num_new_tokens) + + r_meta = LMCacheMPRequestMetadata.GetStoreMetadata( + request_tracker, blocks_per_chunk, self.vllm_block_size + ) + + if r_meta is not None: + metadata.add_request_metadata(r_meta) + + def _get_request_tracker(self, request_id: str) -> LMCacheMPRequestTracker: + assert request_id in self.request_trackers, ( + f"Request tracker for request_id {request_id} not found. " + ) + return self.request_trackers[request_id] + + def _get_or_create_request_tracker( + self, request: "Request" + ) -> LMCacheMPRequestTracker: + request_id = request.request_id + if request_id not in self.request_trackers: + new_tracker = LMCacheMPRequestTracker(request) + self.request_trackers[request_id] = new_tracker + return self.request_trackers[request_id]