Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-18 19:45:26 +08:00)
[KV connector][WIP] KV cache proxy based on LMCache multi-process mode (#27902)

Signed-off-by: ApostaC <yihua98@uchicago.edu>

parent a39dd7bb06, commit 94a9ebcf31
@@ -161,6 +161,12 @@ KVConnectorFactory.register_connector(
     "LMCacheConnectorV1",
 )
 
+KVConnectorFactory.register_connector(
+    "LMCacheMPConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_mp_connector",
+    "LMCacheMPConnector",
+)
+
 KVConnectorFactory.register_connector(
     "NixlConnector",
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
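The new registration makes `LMCacheMPConnector` selectable through the usual KV-transfer configuration. A minimal launch sketch, not part of this commit: it assumes vLLM's existing `KVTransferConfig` fields (`kv_connector`, `kv_role`, `kv_connector_extra_config`) and uses the `lmcache.mp.*` keys and `tcp://localhost:5555` defaults documented in the connector added later in this diff.

# Hypothetical launch sketch (not part of this commit).
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="LMCacheMPConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "lmcache.mp.host": "tcp://localhost",  # default used by the connector
        "lmcache.mp.port": 5555,               # default used by the connector
    },
)
# Pass this as the engine's kv_transfer_config (e.g. via --kv-transfer-config
# JSON when serving), and disable the hybrid KV cache manager as required by
# reformat_block_ids() in the new connector below.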
@@ -2,6 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 
-from . import vllm_v1_adapter
+from . import multi_process_adapter, vllm_v1_adapter
+from .multi_process_adapter import (
+    LMCacheMPSchedulerAdapter,
+    LMCacheMPWorkerAdapter,
+    LoadStoreOp,
+)
 
-__all__ = ["vllm_v1_adapter"]
+__all__ = [
+    "vllm_v1_adapter",
+    "multi_process_adapter",
+    "LMCacheMPSchedulerAdapter",
+    "LMCacheMPWorkerAdapter",
+    "LoadStoreOp",
+]
@@ -0,0 +1,379 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import Any

import torch
import zmq
from lmcache.utils import _lmcache_nvtx_annotate, init_logger
from lmcache.v1.multiprocess.custom_types import (
    CudaIPCWrapper,
    IPCCacheEngineKey,
    KVCache,
)
from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class

logger = init_logger(__name__)


def wrap_kv_caches(kv_caches: dict[str, KVCache]) -> KVCache:
    logger.info("KV caches keys are %s", list(kv_caches.keys()))
    return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()]


def send_lmcache_request(
    mq_client: MessageQueueClient,
    request_type: RequestType,
    payloads: list[Any],
) -> MessagingFuture[Any]:
    future = mq_client.submit_request(
        request_type, payloads, get_response_class(request_type)
    )
    return future


def get_lmcache_chunk_size(
    mq_client: MessageQueueClient,
) -> int:
    future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, [])
    chunk_size = future.result()
    return chunk_size


def striding_block_hashes(
    block_hashes: list[bytes],
    blocks_in_chunk,
) -> Iterable[bytes]:
    """Striding the block hashes to get the block hashes for each chunk.
    For example, if blocks_in_chunk is 16, then we will get the block hashes
    for the 16th, 32nd, 48th, ... blocks.
    """
    return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)


@dataclass
class LoadStoreOp:
    block_hashes: list[bytes]
    block_ids: list[int]

    def __len__(self) -> int:
        return len(self.block_hashes)

    def __post_init__(self):
        assert len(self.block_hashes) == len(self.block_ids), (
            "The number of block hashes should be equal to the number of block ids "
            f"But got {len(self.block_hashes)} and {len(self.block_ids)}"
        )


StoreResult = bool
RetrieveResult = list[bool]
LookupResult = list[bool]

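To make the chunking helpers above concrete, here is a small self-contained sketch (not part of the committed file) of what `striding_block_hashes` selects; it uses `itertools.islice` with the same arguments and a made-up `blocks_in_chunk` of 4.

# Illustrative sketch (not part of this diff): with blocks_in_chunk=4, only the
# hash of the last block of each full chunk is kept, i.e. indices 3, 7, 11, ...
from itertools import islice

block_hashes = [f"hash-{i}".encode() for i in range(10)]
blocks_in_chunk = 4
chunk_hashes = list(islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk))
print(chunk_hashes)  # [b'hash-3', b'hash-7'] -- one key per complete chunk
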
class LMCacheMPSchedulerAdapter:
    def __init__(
        self,
        server_url: str,
        context: zmq.Context,
        model_name: str,
        world_size: int,
        kv_rank: int,
        vllm_block_size: int,
    ):
        """
        Args:
            server_url: The server URL for the LMCache message queue
            context: The ZMQ context

            model_name: The model name used for LMCache keys
            world_size: The world size used for LMCache keys
            kv_rank: The kv rank used for LMCache keys
            vllm_block_size: The block size used in vLLM
        """
        self.mq_client = MessageQueueClient(server_url, context)

        # Request futures
        self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}

        self.model_name = model_name
        self.world_size = world_size
        self.worker_id = kv_rank

        # Read chunk size from lmcache
        self.chunk_size = get_lmcache_chunk_size(self.mq_client)
        assert self.chunk_size % vllm_block_size == 0, (
            "LMCache chunk size should be a multiple of vLLM block size"
        )
        self.blocks_in_chunk = self.chunk_size // vllm_block_size

    @_lmcache_nvtx_annotate
    def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]):
        if request_id in self.lookup_futures:
            # Skip if there is already a lookup request
            return

        s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
        keys = [self._create_key(block_hash) for block_hash in s]
        future = send_lmcache_request(
            self.mq_client,
            RequestType.LOOKUP,
            [keys, True],
        )
        self.lookup_futures[request_id] = future

    @_lmcache_nvtx_annotate
    def check_lookup_result(self, request_id: str) -> int | None:
        assert request_id in self.lookup_futures, (
            f"Lookup request for request_id={request_id} has not been submitted"
        )

        future = self.lookup_futures[request_id]
        if not future.query():
            return None

        result = future.result()
        num_chunks = sum(result)
        return num_chunks * self.chunk_size

    def num_blocks_per_chunk(self) -> int:
        """
        Returns:
            The number of vllm blocks in a LMCache data chunk
        """
        return self.blocks_in_chunk

    # Helper functions
    def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
        """Convert a block hash to an IPC cache engine key"""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            chunk_hash=block_hash,
        )

class LMCacheMPWorkerAdapter:
    def __init__(
        self,
        server_url: str,
        context: zmq.Context,
        model_name: str,
        world_size: int,
        kv_rank: int,
        vllm_block_size: int,
    ):
        self.mq_client = MessageQueueClient(server_url, context)

        # Instance id for GPU worker
        self.instance_id = os.getpid()

        # Registered kv caches from vLLM
        self.kv_caches: dict[str, torch.Tensor] = {}

        # Request futures
        # request_id -> (future, other merged requests)
        self.store_futures: dict[
            str, tuple[MessagingFuture[StoreResult], list[str]]
        ] = {}
        self.retrieve_futures: dict[
            str, tuple[MessagingFuture[RetrieveResult], list[str]]
        ] = {}

        self.finished_stores: set[str] = set()
        self.previously_finished: set[str] = set()

        self.model_name = model_name
        self.world_size = world_size
        self.worker_id = kv_rank

        # Read chunk size from lmcache
        chunk_size = get_lmcache_chunk_size(self.mq_client)
        assert chunk_size % vllm_block_size == 0, (
            "LMCache chunk size should be a multiple of vLLM block size"
        )
        self.blocks_in_chunk = chunk_size // vllm_block_size

    def register_kv_caches(self, kv_caches: dict[str, KVCache]):
        # Register kv cache and send the request
        self.kv_caches = kv_caches
        logger.info("Registering kv caches")
        future = send_lmcache_request(
            self.mq_client,
            RequestType.REGISTER_KV_CACHE,
            [self.instance_id, wrap_kv_caches(kv_caches)],
        )
        future.result()

    @_lmcache_nvtx_annotate
    def submit_store_request(
        self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
    ):
        keys = self._block_hashes_to_keys(op.block_hashes)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.STORE,
            [keys, self.instance_id, op.block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.store_futures[request_id] = (future, [])

    @_lmcache_nvtx_annotate
    def submit_retrieve_request(
        self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
    ):
        keys = self._block_hashes_to_keys(op.block_hashes)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.RETRIEVE,
            [keys, self.instance_id, op.block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.retrieve_futures[request_id] = (future, [])

    @_lmcache_nvtx_annotate
    def batched_submit_store_requests(
        self,
        request_ids: list[str],
        ops: list[LoadStoreOp],
        event: torch.cuda.Event,
    ):
        keys = []
        block_ids = []
        for op in ops:
            keys.extend(self._block_hashes_to_keys(op.block_hashes))
            block_ids.extend(op.block_ids)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.STORE,
            [keys, self.instance_id, block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.store_futures[request_ids[0]] = (future, request_ids[1:])

    @_lmcache_nvtx_annotate
    def batched_submit_retrieve_requests(
        self,
        request_ids: list[str],
        ops: list[LoadStoreOp],
        event: torch.cuda.Event,
    ):
        keys = []
        block_ids = []
        for op in ops:
            keys.extend(self._block_hashes_to_keys(op.block_hashes))
            block_ids.extend(op.block_ids)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.RETRIEVE,
            [keys, self.instance_id, block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.retrieve_futures[request_ids[0]] = (future, request_ids[1:])

    @_lmcache_nvtx_annotate
    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        finished_stores = set()
        finished_retrieves = set()
        for request_id, (future, other_reqs) in self.store_futures.items():
            if not future.query():
                continue

            result = future.result()
            finished_stores.add(request_id)
            finished_stores.update(other_reqs)

            if not result:
                # TODO: add error handling here
                logger.error(
                    "Something went wrong when processing the "
                    "store request for request_id=%s",
                    request_id,
                )

        for request_id, (future, other_reqs) in self.retrieve_futures.items():
            if not future.query():
                continue

            result = future.result()
            finished_retrieves.add(request_id)
            finished_retrieves.update(other_reqs)

            if not all(result):
                # TODO: add error handling here
                logger.error(
                    "Something went wrong when processing the "
                    "retrieve request for request_id=%s, result=%s",
                    request_id,
                    result,
                )
            logger.info("Retrieve request for request_id=%s finished", request_id)

        # Remove the finished requests from the tracking dicts
        for request_id in finished_stores:
            self.store_futures.pop(request_id, None)
        for request_id in finished_retrieves:
            self.retrieve_futures.pop(request_id, None)

        # Update the internal states
        self.finished_stores.update(finished_stores)

        ret_stores = set()
        for req_id in finished_req_ids:
            if req_id in self.finished_stores or req_id in self.store_futures:
                self.previously_finished.add(req_id)
            else:
                ret_stores.add(req_id)

        # Calculate the final finished stores
        ret_stores.update(self._update_and_get_finished_store())

        return ret_stores, finished_retrieves

    def num_blocks_per_chunk(self) -> int:
        """
        Returns:
            The number of vllm blocks in a LMCache data chunk
        """
        return self.blocks_in_chunk

    def shutdown(self):
        # Unregister kv cache
        logger.info("Unregistering kv caches")
        send_lmcache_request(
            self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
        ).result()

        self.mq_client.close()

    # Helper functions
    def _update_and_get_finished_store(
        self,
    ) -> set[str]:
        """Converge the internal states about finished stores
        and returns the 'safe finished store request ids' back
        """
        safe_finished_s = self.finished_stores.intersection(self.previously_finished)
        self.finished_stores.difference_update(self.previously_finished)
        self.previously_finished.difference_update(safe_finished_s)

        return safe_finished_s

    def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
        """Convert a block hash to an IPC cache engine key"""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            chunk_hash=block_hash,
        )

    def _block_hashes_to_keys(
        self, block_hashes: list[bytes]
    ) -> list[IPCCacheEngineKey]:
        """Convert block hashes to IPC cache engine keys"""
        s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
        return [self._create_key(block_hash) for block_hash in s]
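The trickiest part of the worker adapter is the deferred-completion handshake in `get_finished()`: a request that vLLM reports as finished is only released once its store future has also completed, via the `finished_stores` / `previously_finished` convergence in `_update_and_get_finished_store()`. A pure-Python sketch of that handshake (the request ids are made up; this is not part of the diff):

# Sketch of the two-phase release in get_finished(); a request is released only
# after BOTH vLLM marked it finished AND its store future completed.
finished_stores: set[str] = set()      # store futures that have completed
previously_finished: set[str] = set()  # ids vLLM finished while a store was pending

def step(vllm_finished: set[str], stores_done: set[str], stores_pending: set[str]) -> set[str]:
    finished_stores.update(stores_done)
    released = set()
    for req_id in vllm_finished:
        if req_id in finished_stores or req_id in stores_pending:
            previously_finished.add(req_id)   # hold back until the store settles
        else:
            released.add(req_id)              # nothing pending for this id
    # mirror _update_and_get_finished_store()
    safe = finished_stores & previously_finished
    finished_stores.difference_update(previously_finished)
    previously_finished.difference_update(safe)
    return released | safe

print(step({"req-1"}, set(), {"req-1"}))  # set(): store still in flight
print(step(set(), {"req-1"}, set()))      # {'req-1'}: released on a later step
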
@@ -0,0 +1,867 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

import torch
import zmq
from lmcache.utils import init_logger as lmcache_init_logger

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
    LMCacheMPSchedulerAdapter,
    LMCacheMPWorkerAdapter,
    LoadStoreOp,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.config import VllmConfig
    from vllm.distributed.kv_events import KVCacheEvent
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorPromMetrics,
        KVConnectorStats,
        PromMetric,
        PromMetricT,
    )
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.core.kv_cache_utils import BlockHash
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

logger = lmcache_init_logger(__name__)


# Helper functions
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
    if block_ids is None:
        return []
    assert isinstance(block_ids, tuple), (
        f"Expected block_ids to be a tuple of lists, but got {type(block_ids)}"
    )

    if len(block_ids) > 1:
        raise RuntimeError(
            "LMCacheMPConnector only works without hybrid kv cache manager. "
            "Please pass --disable-hybrid-kv-cache-manager when starting vllm"
        )

    return block_ids[0]


def create_scheduler_adapter(
    server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPSchedulerAdapter:
    # TODO: have a helper function to calculate the correct rank and
    # world size for the MLA and other models
    return LMCacheMPSchedulerAdapter(
        server_url,
        zmq_context,
        vllm_config.model_config.model,
        vllm_config.parallel_config.world_size,
        vllm_config.parallel_config.rank,
        vllm_config.cache_config.block_size,
    )


def create_worker_adapter(
    server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPWorkerAdapter:
    # TODO: have a helper function to calculate the correct rank and
    # world size for the MLA and other models
    return LMCacheMPWorkerAdapter(
        server_url,
        zmq_context,
        vllm_config.model_config.model,
        vllm_config.parallel_config.world_size,
        vllm_config.parallel_config.rank,
        vllm_config.cache_config.block_size,
    )


def convert_block_hashes_to_bytes(
    block_hashes: list["BlockHash"],
) -> list[bytes]:
    return cast(list[bytes], block_hashes)

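A small sketch (not part of this diff) of the `reformat_block_ids` contract: the connector expects a tuple holding a single list of block ids (one list per KV cache group) and rejects anything longer, which is why the hybrid KV cache manager must be disabled.

# Illustrative sketch of the block-id reformatting contract (made-up ids).
def reformat_block_ids_sketch(block_ids):
    if block_ids is None:
        return []
    if len(block_ids) > 1:
        raise RuntimeError("disable the hybrid KV cache manager")
    return block_ids[0]

print(reformat_block_ids_sketch(([7, 8, 9],)))  # [7, 8, 9]
print(reformat_block_ids_sketch(None))          # []
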
class LMCacheMPRequestState(enum.Enum):
    """
    State machine:
    PREFETCHING -- update_state_after_alloc --> WAITING_FOR_LOAD
    WAITING_FOR_LOAD -- process_loading_requests --> READY
    """

    PREFETCHING = enum.auto()
    WAITING_FOR_LOAD = enum.auto()
    READY = enum.auto()


@dataclass
class LMCacheMPRequestTracker:
    # NOTE: this class uses vLLM data structures, so it should be part of
    # vLLM integration code

    request_id: str

    # Read-only lists to track the token ids and block hashes
    all_token_ids: ConstantList[int]
    block_hashes: ConstantList["BlockHash"]

    # Block ids and hashes will be updated at update_state_after_alloc and
    # during the generation
    allocated_block_ids: list[int] = field(default_factory=list)

    # Number of scheduled tokens in this request. We keep tracking this to
    # avoid saving half-full blocks.
    num_scheduled_tokens: int = 0

    # Number of blocks stored will be initialized when looking up the external
    # hit tokens and will be updated when processing new requests and cached
    # requests.
    num_stored_blocks: int = 0

    # Staging load operation -- save vllm and lmcache hit tokens during lookup
    num_vllm_hit_blocks: int = 0
    num_lmcache_hit_blocks: int = 0

    # Main state
    state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING

    def __init__(self, request: "Request"):
        self.request_id = request.request_id
        self.all_token_ids = request.all_token_ids
        self.block_hashes = ConstantList(request.block_hashes)
        self.allocated_block_ids = []
        self.num_stored_blocks = 0
        self.num_vllm_hit_blocks = 0
        self.num_lmcache_hit_blocks = 0
        self.state = LMCacheMPRequestState.PREFETCHING

    ####
    # Check the state of the request
    ####
    def needs_retrieve(self) -> bool:
        """Check whether the current request needs retrieve; will be used
        in update_state_after_alloc"""
        return (
            self.num_lmcache_hit_blocks > self.num_vllm_hit_blocks
            and self.state != LMCacheMPRequestState.READY
        )

    def is_ready_for_retrieving(self) -> bool:
        """Check whether the current request is ready for retrieving;
        will be used in process_loading_requests"""
        return (
            self.state == LMCacheMPRequestState.WAITING_FOR_LOAD
            and self.needs_retrieve()
        )

    ####
    # Update internal states
    ####
    def increase_num_scheduled_tokens(self, num_new_tokens: int):
        self.num_scheduled_tokens += num_new_tokens

    def increase_num_stored_blocks(self, num_new_blocks: int):
        """Increase the number of stored blocks for the current request
        This function will be called when processing the cached requests.
        """
        self.num_stored_blocks += num_new_blocks

    def update_block_ids(
        self,
        new_block_ids: list[int],
    ):
        """Update the block ids for the current request
        This function will be called when processing the cached requests.
        """
        self.allocated_block_ids.extend(new_block_ids)

    ####
    # For debugging
    ####
    def __repr__(self) -> str:
        return (
            f"LMCacheMPRequestTracker(request_id={self.request_id}, "
            f"num_tokens={len(self.all_token_ids)}, "
            f"num_block_hashes={len(self.block_hashes)}, "
            f"num_allocated_blocks={len(self.allocated_block_ids)}, "
            f"num_stored_blocks={self.num_stored_blocks}, "
            f"vllm_hit_blocks={self.num_vllm_hit_blocks}, "
            f"lmcache_hit_blocks={self.num_lmcache_hit_blocks}, "
            f"state={self.state})"
        )

    def __str__(self) -> str:
        return self.__repr__()


@dataclass
class LMCacheMPRequestMetadata:
    request_id: str
    direction: Literal["STORE", "RETRIEVE"]
    op: LoadStoreOp

    @staticmethod
    def GetStoreMetadata(
        tracker: LMCacheMPRequestTracker,
        blocks_in_chunk: int,
        vllm_block_size: int,
    ) -> "LMCacheMPRequestMetadata | None":
        """
        Generate the store metadata for the current request tracker.

        Args:
            tracker: The request tracker to generate the metadata from.
            blocks_in_chunk: the number of blocks in a LMCache data chunk
        """
        # Store the blocks that have block hashes
        # NOTE: the invariant here is that `num_stored_blocks` should
        # always be a multiple of `blocks_in_chunk`
        # TODO: This should be checked every time we update the num_stored_blocks
        min_available_blocks = min(
            len(tracker.block_hashes),
            len(tracker.allocated_block_ids),
            tracker.num_scheduled_tokens // vllm_block_size,
        )
        num_staging_blocks = min_available_blocks - tracker.num_stored_blocks
        num_chunks = num_staging_blocks // blocks_in_chunk

        if num_chunks >= 1:
            start = tracker.num_stored_blocks
            end = start + num_chunks * blocks_in_chunk
            block_hashes = convert_block_hashes_to_bytes(
                tracker.block_hashes[start:end]
            )
            block_ids = tracker.allocated_block_ids[start:end]

            ret = LMCacheMPRequestMetadata(
                request_id=tracker.request_id,
                direction="STORE",
                op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
            )

            # Update the request tracker
            tracker.increase_num_stored_blocks(end - start)
            return ret

        return None

    @staticmethod
    def GetRetrieveMetadata(
        tracker: LMCacheMPRequestTracker,
        blocks_in_chunk: int,
    ) -> "LMCacheMPRequestMetadata | None":
        """
        Generate the retrieve metadata for the current request tracker.

        Args:
            tracker: The request tracker to generate the metadata from.
            blocks_in_chunk: the number of blocks in a LMCache data chunk
        """
        if not tracker.is_ready_for_retrieving():
            return None

        # |---------------------|-----------------|----------------|
        # | num_vllm_hit_blocks |
        # | lmcache chunk 1     | lmcache chunk 2 |
        #           |       need to retrieve      |

        start = tracker.num_vllm_hit_blocks // blocks_in_chunk * blocks_in_chunk
        end = tracker.num_lmcache_hit_blocks
        assert end % blocks_in_chunk == 0, (
            "The number of LMCache hit blocks should be a multiple of the "
            "number of blocks in a lmcache chunk. "
        )
        assert len(tracker.block_hashes) >= end, (
            "The number of block hashes should be greater than or equal to the "
            "number of LMCache hit blocks. "
        )
        if end > start:
            block_hashes = convert_block_hashes_to_bytes(
                tracker.block_hashes[start:end]
            )
            block_ids = tracker.allocated_block_ids[start:end]

            ret = LMCacheMPRequestMetadata(
                request_id=tracker.request_id,
                direction="RETRIEVE",
                op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
            )
            return ret

        return None

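A worked numeric example of the store and retrieve windows computed above, with made-up sizes (`vllm_block_size=16`, `blocks_in_chunk=16`, i.e. one LMCache chunk covers 256 tokens). It only re-runs the arithmetic from the two static methods and is not part of the committed file.

# Store side (GetStoreMetadata): only whole chunks beyond what is stored.
vllm_block_size = 16
blocks_in_chunk = 16
num_block_hashes = 40
num_allocated_block_ids = 40
num_scheduled_tokens = 640    # 40 full blocks
num_stored_blocks = 16        # one chunk already stored
min_available_blocks = min(num_block_hashes, num_allocated_block_ids,
                           num_scheduled_tokens // vllm_block_size)      # 40
num_staging_blocks = min_available_blocks - num_stored_blocks            # 24
num_chunks = num_staging_blocks // blocks_in_chunk                       # 1
store_start = num_stored_blocks
store_end = store_start + num_chunks * blocks_in_chunk
print(store_start, store_end)  # 16 32 -> blocks [16, 32) are stored this step

# Retrieve side (GetRetrieveMetadata): start at the chunk boundary at or below
# the vLLM prefix-cache hit, end at the LMCache hit.
num_vllm_hit_blocks = 20       # vLLM already has 20 blocks locally
num_lmcache_hit_blocks = 32    # LMCache can provide 2 full chunks
retrieve_start = num_vllm_hit_blocks // blocks_in_chunk * blocks_in_chunk  # 16
retrieve_end = num_lmcache_hit_blocks                                      # 32
print(retrieve_start, retrieve_end)  # 16 32 -> blocks [16, 32) are loaded
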
class LMCacheMPConnectorMetadata(KVConnectorMetadata):
    def __init__(self):
        super().__init__()
        self.requests: list[LMCacheMPRequestMetadata] = []

    def add_request_metadata(self, request_metadata: LMCacheMPRequestMetadata):
        self.requests.append(request_metadata)

    def __len__(self):
        return len(self.requests)

    # For debugging
    def __str__(self):
        request_strs = []
        for req_meta in self.requests:
            request_strs.append(
                f"RequestMetadata(request_id={req_meta.request_id}, "
                f"direction={req_meta.direction}, "
                f"num_blocks={len(req_meta.op)}, "
                f"block_ids={req_meta.op.block_ids})"
            )
        return "[" + "\n".join(request_strs) + "]"

    def __repr__(self):
        return self.__str__()

class LMCacheMPConnector(KVConnectorBase_V1):
    """
    The connector for LMCache multi-process mode.

    Extra configs (kv_transfer_config.extra_config):
    - lmcache.mp.host: the host of the LMCache server.
    - lmcache.mp.port: the port of the LMCache server.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        assert vllm_config.kv_transfer_config is not None
        server_host = vllm_config.kv_transfer_config.get_from_extra_config(
            "lmcache.mp.host", "tcp://localhost"
        )
        server_port = vllm_config.kv_transfer_config.get_from_extra_config(
            "lmcache.mp.port", 5555
        )

        server_url = f"{server_host}:{server_port}"
        zmq_context = zmq.Context.instance()
        if self.role == KVConnectorRole.SCHEDULER:
            self.scheduler_adapter = create_scheduler_adapter(
                server_url, zmq_context, vllm_config
            )
            self.request_trackers: dict[str, LMCacheMPRequestTracker] = {}
        elif self.role == KVConnectorRole.WORKER:
            self.worker_adapter = create_worker_adapter(
                server_url, zmq_context, vllm_config
            )
        else:
            raise ValueError(f"Unknown KVConnectorRole: {self.role}")

        self.vllm_block_size = vllm_config.cache_config.block_size

    @property
    def role(self) -> KVConnectorRole:
        return self._role

    # ==============================
    # Worker-side methods
    # ==============================

    def _get_connector_metadata(self) -> KVConnectorMetadata:
        """Get the connector metadata.

        This function should only be called inside the connector.

        Returns:
            ConnectorMetadata: the connector metadata.
        """

        # Should only be called while set to valid metadata.
        assert self._connector_metadata is not None
        return self._connector_metadata

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args:
            kv_caches: dictionary of layer names, kv cache
        """
        logger.info("Registering kv caches!")
        self.worker_adapter.register_kv_caches(kv_caches)
        return

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """
        metadata = self._get_connector_metadata()
        assert isinstance(metadata, LMCacheMPConnectorMetadata)

        with torch.cuda.stream(torch.cuda.current_stream()):
            event = torch.cuda.Event(interprocess=True)
            event.record()

        request_ids = []
        ops = []

        for meta in metadata.requests:
            if meta.direction != "RETRIEVE":
                continue
            request_ids.append(meta.request_id)
            ops.append(meta.op)

        if len(request_ids) > 0:
            logger.info(
                "HERE! SUBMITTING THE BATCHED RETRIEVE REQUESTS %s", request_ids
            )
            self.worker_adapter.batched_submit_retrieve_requests(
                request_ids, ops, event
            )

    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within attention layer to ensure
        async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        return

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        return

    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of paged KV buffer before saving is done.
        """
        metadata = self._get_connector_metadata()
        assert isinstance(metadata, LMCacheMPConnectorMetadata)

        with torch.cuda.stream(torch.cuda.current_stream()):
            event = torch.cuda.Event(interprocess=True)
            event.record()

        request_ids = []
        ops = []
        for meta in metadata.requests:
            if meta.direction != "STORE":
                continue
            request_ids.append(meta.request_id)
            ops.append(meta.op)

        if len(request_ids) > 0:
            self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens on the worker.
        The scheduler process (via the Executors) will use this output
        to track which workers are done.

        Returns:
            ids of requests that have finished asynchronous transfer
            (requests that previously returned True from request_finished()),
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """
        val = self.worker_adapter.get_finished(finished_req_ids)
        # logger.error("Finished req ids: %s, %s", val[0], val[1])
        return val

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Get the set of block IDs that failed to load.

        Returns:
            Set of block IDs that encountered load errors.
            Empty set if no load errors occurred.

        Notes:
            - Applies to both sync- and async-loading requests.
            - Async loading: failed blocks may be reported in any forward pass
              up to and including the pass where the request ID is returned by
              `get_finished()`. Even if failures occur, the request must still
              be reported via `get_finished()`, and the failed block IDs must
              appear here no later than that same pass.
            - Sync loading: failed blocks should be reported in the forward
              pass in which they are detected.
        """
        # TODO: add error tracking
        return set()

    def shutdown(self):
        """
        Shutdown the connector. This is called when the worker process
        is shutting down to ensure that all the async operations are
        completed and the connector is cleaned up properly.
        """
        if hasattr(self, "worker_adapter"):
            self.worker_adapter.shutdown()
        return None

    def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
        """
        Get the KV connector stats collected during the last interval.
        """
        return None

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - An optional number of tokens that can be loaded from the
                  external KV cache beyond what is already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if external KV cache tokens will be loaded
                  asynchronously (between scheduler steps). Must be
                  'False' if the first element is 0.

        Notes:
            The connector should only consider the largest prefix of prompt
            tokens for which KV cache is actually available at the time of the
            call. If the cache cannot be loaded for some tokens (e.g., due to
            connectivity issues or eviction), those tokens must not be taken
            into account.
        """
        tracker = self._get_or_create_request_tracker(request)

        self.scheduler_adapter.maybe_submit_lookup_request(
            request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
        )

        ret = self.scheduler_adapter.check_lookup_result(request.request_id)
        if ret is None:
            return None, True

        if ret == 0:
            return 0, False

        assert (
            ret % (self.scheduler_adapter.num_blocks_per_chunk() * self.vllm_block_size)
            == 0
        )

        # Update num stored blocks for the tracker
        num_vllm_blocks = num_computed_tokens // self.vllm_block_size
        num_lmcache_blocks = ret // self.vllm_block_size
        tracker.increase_num_stored_blocks(num_lmcache_blocks)

        # Save the vllm and lmcache hit tokens
        tracker.num_vllm_hit_blocks = num_vllm_blocks
        tracker.num_lmcache_hit_blocks = num_lmcache_blocks

        need_to_load = max(0, ret - num_computed_tokens)
        logger.debug(
            "vLLM hit is: %d, Need to load is %d", num_computed_tokens, need_to_load
        )
        return need_to_load, need_to_load > 0

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If get_num_new_matched_tokens previously returned True for a
        request, this function may be called twice for that same request -
        first when blocks are allocated for the connector tokens to be
        asynchronously loaded into, and second when any additional blocks
        are allocated, after the load/transfer is complete.

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the blocks allocated for the request.
            num_external_tokens (int): the number of tokens that will be
                loaded from the external KV cache.
        """
        # NOTE: the `blocks` are NEW BLOCKS allocated for this request.
        tracker = self._get_request_tracker(request.request_id)
        block_ids = reformat_block_ids(blocks.get_block_ids())

        # Whether or not we need to retrieve, we need to update
        # the block ids in the tracker
        tracker.update_block_ids(block_ids)

        # Update the state of the tracker
        condition = tracker.needs_retrieve()
        if tracker.state == LMCacheMPRequestState.PREFETCHING:
            # If we need to retrieve, change to WAITING_FOR_LOAD
            # Otherwise, change to READY
            tracker.state = (
                LMCacheMPRequestState.WAITING_FOR_LOAD
                if condition
                else LMCacheMPRequestState.READY
            )

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        metadata = LMCacheMPConnectorMetadata()

        self._process_retrieve_requests(metadata)
        self._process_new_requests(scheduler_output, metadata)
        self._process_cached_requests(scheduler_output, metadata)

        if len(metadata) > 0:
            logger.debug("Final connector metadata: %s", metadata)

        return metadata

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        return

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called exactly once when a request has finished, before its blocks are
        freed.

        The connector may assume responsibility for freeing the blocks
        asynchronously by returning True.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return True, None

    def take_events(self) -> Iterable["KVCacheEvent"]:
        """
        Take the KV cache events from the connector.

        Yields:
            New KV cache events since the last call.
        """
        return ()

    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
        """
        Get the required KV cache layout for this connector.

        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout. e.g. HND, or NHD.
            None if the connector does not require a specific layout.
        """

        if cls is KVConnectorBase_V1:
            raise TypeError(
                "get_required_kvcache_layout should not be called "
                "on the abstract base class"
            )
        return None

    def get_finished_count(self) -> int | None:
        """
        Get the count of requests expected to complete send/receive operations
        via this connector. This method is used to initialize the
        KVOutputAggregator, overwriting the default world_size.

        Returns:
            int: expected sending or receiving completion count.
        """
        return None

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> Optional["KVConnectorStats"]:
        """
        KVConnectorStats resolution method. This method allows dynamically
        registered connectors to return their own KVConnectorStats object,
        which can implement custom aggregation logic on the data dict.
        """
        return None

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: "VllmConfig",
        metric_types: dict[type["PromMetric"], type["PromMetricT"]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[str]],
    ) -> Optional["KVConnectorPromMetrics"]:
        """
        Create a KVConnectorPromMetrics subclass which should register
        per-connector Prometheus metrics and implement observe() to
        expose connector transfer stats via Prometheus.
        """
        return None

    ##############################
    # Helper functions
    ##############################
    def _process_retrieve_requests(
        self,
        metadata: LMCacheMPConnectorMetadata,
    ) -> None:
        blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()

        for request_tracker in self.request_trackers.values():
            if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD:
                continue
            r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata(
                request_tracker, blocks_per_chunk
            )
            if r_metadata is not None:
                metadata.add_request_metadata(r_metadata)
            request_tracker.state = LMCacheMPRequestState.READY

    def _process_new_requests(
        self,
        scheduler_output: SchedulerOutput,
        metadata: LMCacheMPConnectorMetadata,
    ) -> None:
        blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()

        for new_request in scheduler_output.scheduled_new_reqs:
            request_tracker = self._get_request_tracker(new_request.req_id)

            num_new_tokens = scheduler_output.num_scheduled_tokens[new_request.req_id]
            request_tracker.increase_num_scheduled_tokens(num_new_tokens)

            r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
                request_tracker, blocks_per_chunk, self.vllm_block_size
            )
            if r_meta is not None:
                metadata.add_request_metadata(r_meta)

    def _process_cached_requests(
        self,
        scheduler_output: SchedulerOutput,
        metadata: LMCacheMPConnectorMetadata,
    ) -> None:
        blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()

        cached_reqs = scheduler_output.scheduled_cached_reqs
        for idx, request_id in enumerate(cached_reqs.req_ids):
            request_tracker = self._get_request_tracker(request_id)

            # Update block ids
            new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
            request_tracker.update_block_ids(new_block_ids)

            # Update new scheduled tokens
            num_new_tokens = cached_reqs.num_computed_tokens[idx]
            request_tracker.increase_num_scheduled_tokens(num_new_tokens)

            r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
                request_tracker, blocks_per_chunk, self.vllm_block_size
            )

            if r_meta is not None:
                metadata.add_request_metadata(r_meta)

    def _get_request_tracker(self, request_id: str) -> LMCacheMPRequestTracker:
        assert request_id in self.request_trackers, (
            f"Request tracker for request_id {request_id} not found. "
        )
        return self.request_trackers[request_id]

    def _get_or_create_request_tracker(
        self, request: "Request"
    ) -> LMCacheMPRequestTracker:
        request_id = request.request_id
        if request_id not in self.request_trackers:
            new_tracker = LMCacheMPRequestTracker(request)
            self.request_trackers[request_id] = new_tracker
        return self.request_trackers[request_id]
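Finally, a sketch of the asynchronous lookup contract that `get_num_new_matched_tokens()` implements on the scheduler side: while the LMCache lookup is still in flight the connector returns `(None, True)` so the scheduler queries again later; once the result arrives it returns how many extra tokens to load and whether they load asynchronously. `FakeLookup` below is a made-up stand-in for the real scheduler adapter and is not part of this diff.

# Illustrative sketch of the (None, True) / (need_to_load, async) contract.
class FakeLookup:
    def __init__(self, hit_tokens: int, ready_after: int):
        self.hit_tokens = hit_tokens
        self.polls_left = ready_after

    def check(self) -> int | None:
        if self.polls_left > 0:
            self.polls_left -= 1
            return None           # lookup still pending
        return self.hit_tokens    # chunk-aligned number of hit tokens

def get_num_new_matched_tokens(lookup: FakeLookup, num_computed_tokens: int):
    ret = lookup.check()
    if ret is None:
        return None, True         # ask the scheduler to query again later
    if ret == 0:
        return 0, False
    need_to_load = max(0, ret - num_computed_tokens)
    return need_to_load, need_to_load > 0

lookup = FakeLookup(hit_tokens=512, ready_after=1)
print(get_num_new_matched_tokens(lookup, num_computed_tokens=256))  # (None, True)
print(get_num_new_matched_tokens(lookup, num_computed_tokens=256))  # (256, True)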