diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
index bf9117133af56..42de227b6c309 100644
--- a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
@@ -38,7 +38,8 @@ class LMCacheConnector(KVConnectorBase):
         from lmcache.integration.vllm.utils import ENGINE_NAME
         from lmcache.integration.vllm.vllm_adapter import (
             RetrieveStatus, StoreStatus, init_lmcache_engine,
-            lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv)
+            lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
+            lmcache_store_kv)
 
         logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                     self.transfer_config)
@@ -54,6 +55,7 @@ class LMCacheConnector(KVConnectorBase):
         self.cache_config = config.cache_config
         self.lmcache_retrieve_kv = lmcache_retrieve_kv
         self.lmcache_store_kv = lmcache_store_kv
+        self.lmcache_should_retrieve = lmcache_should_retrieve
         self.lmcache_should_store = lmcache_should_store
         self.store_status = StoreStatus
         self.retrieve_status = RetrieveStatus
@@ -65,15 +67,11 @@ class LMCacheConnector(KVConnectorBase):
     ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:
 
-        hidden_or_intermediate_states = None
-
-        # TODO (Jiayi): Need to support chunked prefill
-        retrieve_status = self.retrieve_status.PREFILL
-
-        model_input, bypass_model_exec = self.lmcache_retrieve_kv(
-            model_executable, model_input, self.cache_config, kv_caches,
-            retrieve_status)
-
+        retrieve_status = self.lmcache_should_retrieve(model_input)
+        model_input, bypass_model_exec, hidden_or_intermediate_states =\
+            self.lmcache_retrieve_kv(
+                model_executable, model_input, self.cache_config, kv_caches,
+                retrieve_status)
         return hidden_or_intermediate_states, bypass_model_exec, model_input
 
     def send_kv_caches_and_hidden_states(
@@ -84,15 +82,7 @@ class LMCacheConnector(KVConnectorBase):
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
-        num_reqs = 0
-        seq_group_list = model_input.sampling_metadata.seq_groups
-        assert seq_group_list is not None
-        for seq_group in seq_group_list:
-            seq_ids = seq_group.seq_ids
-            for seq_id in seq_ids:
-                num_reqs += 1
 
-        # TODO (Jiayi): Only normal prefill is supported for now
         store_status = self.lmcache_should_store(model_input)
         self.lmcache_store_kv(
             self.model_config,