[Feat] Support chunked prefill for LMCache connector (#14505)

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
This commit is contained in:
Jiayi Yao 2025-03-08 21:30:06 -06:00 committed by GitHub
parent 10f7552789
commit 6d7f037748
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -38,7 +38,8 @@ class LMCacheConnector(KVConnectorBase):
from lmcache.integration.vllm.utils import ENGINE_NAME
from lmcache.integration.vllm.vllm_adapter import (
RetrieveStatus, StoreStatus, init_lmcache_engine,
lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv)
lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
lmcache_store_kv)
logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
self.transfer_config)
@ -54,6 +55,7 @@ class LMCacheConnector(KVConnectorBase):
self.cache_config = config.cache_config
self.lmcache_retrieve_kv = lmcache_retrieve_kv
self.lmcache_store_kv = lmcache_store_kv
self.lmcache_should_retrieve = lmcache_should_retrieve
self.lmcache_should_store = lmcache_should_store
self.store_status = StoreStatus
self.retrieve_status = RetrieveStatus
@ -65,15 +67,11 @@ class LMCacheConnector(KVConnectorBase):
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
hidden_or_intermediate_states = None
# TODO (Jiayi): Need to support chunked prefill
retrieve_status = self.retrieve_status.PREFILL
model_input, bypass_model_exec = self.lmcache_retrieve_kv(
model_executable, model_input, self.cache_config, kv_caches,
retrieve_status)
retrieve_status = self.lmcache_should_retrieve(model_input)
model_input, bypass_model_exec, hidden_or_intermediate_states =\
self.lmcache_retrieve_kv(
model_executable, model_input, self.cache_config, kv_caches,
retrieve_status)
return hidden_or_intermediate_states, bypass_model_exec, model_input
def send_kv_caches_and_hidden_states(
@ -84,15 +82,7 @@ class LMCacheConnector(KVConnectorBase):
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
num_reqs = 0
seq_group_list = model_input.sampling_metadata.seq_groups
assert seq_group_list is not None
for seq_group in seq_group_list:
seq_ids = seq_group.seq_ids
for seq_id in seq_ids:
num_reqs += 1
# TODO (Jiayi): Only normal prefill is supported for now
store_status = self.lmcache_should_store(model_input)
self.lmcache_store_kv(
self.model_config,