diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 3f60fbd6455a2..ad907c75a244b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -44,8 +44,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils impo ) from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group from vllm.sampling_params import SamplingParams -from vllm.utils import get_kv_cache_torch_dtype from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import get_kv_cache_torch_dtype from vllm.v1.core.sched.output import SchedulerOutput from vllm.version import __version__ as VLLM_VERSION @@ -389,7 +389,7 @@ class ReqMeta: def need_gpu_interm_buffer(lmcache_config: LMCacheEngineConfig): - return lmcache_config.enable_pd + return not lmcache_config.enable_pd def _calculate_mtp_layers(vllm_config, model_config): @@ -403,6 +403,20 @@ def _calculate_mtp_layers(vllm_config, model_config): num_mtp_layers = getattr( model_config.hf_config, "num_nextn_predict_layers", 0 ) + + elif vllm_config.speculative_config.use_eagle(): + try: + draft_model_config = vllm_config.speculative_config.draft_model_config + num_mtp_layers = draft_model_config.get_num_layers( + vllm_config.parallel_config + ) + logger.info("EAGLE detected %d extra layer(s)", num_mtp_layers) + except Exception: + logger.info( + "EAGLE detected, but failed to get the number of extra layers, " + "falling back to 1" + ) + num_mtp_layers = 1 return num_mtp_layers @@ -1208,6 +1222,10 @@ class LMCacheConnectorV1Impl: if the CacheManager this allocated blocks for us. """ + # Clear local status in lookup client when a new request is + # successfully scheduled. 
+ self.lookup_client.clear_lookup_status(request.request_id) + kv_transfer_params = ( request.kv_transfer_params if hasattr(request, "kv_transfer_params")