From 6d7f037748b2e7df64f3318e54101a1c80016f3c Mon Sep 17 00:00:00 2001
From: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Date: Sat, 8 Mar 2025 21:30:06 -0600
Subject: [PATCH] [Feat] Support chunked prefill for LMCache connector (#14505)

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
---
 .../kv_connector/lmcache_connector.py         | 26 ++++++-------------
 1 file changed, 8 insertions(+), 18 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
index bf9117133af56..42de227b6c309 100644
--- a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
@@ -38,7 +38,8 @@ class LMCacheConnector(KVConnectorBase):
         from lmcache.integration.vllm.utils import ENGINE_NAME
         from lmcache.integration.vllm.vllm_adapter import (
             RetrieveStatus, StoreStatus, init_lmcache_engine,
-            lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv)
+            lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
+            lmcache_store_kv)
 
         logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                     self.transfer_config)
@@ -54,6 +55,7 @@ class LMCacheConnector(KVConnectorBase):
         self.cache_config = config.cache_config
         self.lmcache_retrieve_kv = lmcache_retrieve_kv
         self.lmcache_store_kv = lmcache_store_kv
+        self.lmcache_should_retrieve = lmcache_should_retrieve
         self.lmcache_should_store = lmcache_should_store
         self.store_status = StoreStatus
         self.retrieve_status = RetrieveStatus
@@ -65,15 +67,11 @@ class LMCacheConnector(KVConnectorBase):
     ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:
 
-        hidden_or_intermediate_states = None
-
-        # TODO (Jiayi): Need to support chunked prefill
-        retrieve_status = self.retrieve_status.PREFILL
-
-        model_input, bypass_model_exec = self.lmcache_retrieve_kv(
-            model_executable, model_input, self.cache_config, kv_caches,
-            retrieve_status)
-
+        retrieve_status = self.lmcache_should_retrieve(model_input)
+        model_input, bypass_model_exec, hidden_or_intermediate_states =\
+            self.lmcache_retrieve_kv(
+                model_executable, model_input, self.cache_config, kv_caches,
+                retrieve_status)
         return hidden_or_intermediate_states, bypass_model_exec, model_input
 
     def send_kv_caches_and_hidden_states(
@@ -84,15 +82,7 @@ class LMCacheConnector(KVConnectorBase):
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
 
-        num_reqs = 0
-        seq_group_list = model_input.sampling_metadata.seq_groups
-        assert seq_group_list is not None
-        for seq_group in seq_group_list:
-            seq_ids = seq_group.seq_ids
-            for seq_id in seq_ids:
-                num_reqs += 1
-        # TODO (Jiayi): Only normal prefill is supported for now
         store_status = self.lmcache_should_store(model_input)
         self.lmcache_store_kv(
             self.model_config,
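
Note (illustrative, not part of the patch): the substance of this change is that
the connector no longer hardcodes retrieve_status = self.retrieve_status.PREFILL.
It now asks LMCache per batch, via lmcache_should_retrieve(model_input) and
lmcache_should_store(model_input), whether and how to retrieve or store, and
lmcache_retrieve_kv now returns the triple (model_input, bypass_model_exec,
hidden_or_intermediate_states), so a full-prefill cache hit can still bypass
model execution. The sketch below shows that control flow in isolation; the
RetrieveStatus values, the Batch dataclass, and both stub bodies are
hypothetical stand-ins, not the LMCache implementation.

# sketch_chunked_prefill_flow.py: illustrative only, not LMCache/vLLM code.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Optional, Tuple


class RetrieveStatus(Enum):
    # Hypothetical stand-in for lmcache's RetrieveStatus enum.
    PREFILL = auto()        # full prompt in one batch: retrieval may bypass the model
    CHUNK_PREFILL = auto()  # only a chunk of the prompt: model must still run
    NONE = auto()           # e.g. a decode-only batch: nothing to retrieve


@dataclass
class Batch:
    # Hypothetical stand-in for ModelInputForGPUWithSamplingMetadata.
    is_prompt: bool
    chunked: bool


def should_retrieve(batch: Batch) -> RetrieveStatus:
    # Stand-in for lmcache_should_retrieve(model_input): the status decision
    # now lives with LMCache instead of being hardcoded to PREFILL.
    if not batch.is_prompt:
        return RetrieveStatus.NONE
    return RetrieveStatus.CHUNK_PREFILL if batch.chunked else RetrieveStatus.PREFILL


def retrieve_kv(batch: Batch,
                status: RetrieveStatus) -> Tuple[Batch, bool, Optional[Any]]:
    # Stand-in for lmcache_retrieve_kv(...), returning the same triple the
    # patched connector forwards:
    # (model_input, bypass_model_exec, hidden_or_intermediate_states).
    if status is RetrieveStatus.PREFILL:
        # Full-prefill hit: hidden states are available too, so the forward
        # pass can be skipped entirely.
        return batch, True, "cached_hidden_states"
    # Chunked prefill (or decode): injected KV can cover earlier chunks, but
    # the current chunk still has to run through the model.
    return batch, False, None


if __name__ == "__main__":
    for batch in (Batch(is_prompt=True, chunked=False),
                  Batch(is_prompt=True, chunked=True),
                  Batch(is_prompt=False, chunked=False)):
        status = should_retrieve(batch)
        _, bypass, hidden = retrieve_kv(batch, status)
        print(f"{status.name:<14} bypass_model_exec={bypass} hidden={hidden}")

Running the sketch prints one line per batch type. The point it illustrates is
that a chunked-prefill batch never sets bypass_model_exec, which is why the old
hardcoded PREFILL status could not support chunked prefill.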