diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
index 55831dc56c803..22ddabbf1e352 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
 
 import torch
 import zmq
+from lmcache.integration.vllm.utils import mla_enabled
 from lmcache.utils import init_logger as lmcache_init_logger
 
 from vllm.config import VllmConfig
@@ -60,17 +61,44 @@ def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
     return block_ids[0]
 
 
+def extract_world_size_and_kv_rank(
+    world_size: int,
+    rank: int,
+    vllm_config: VllmConfig,
+) -> tuple[int, int]:
+    """
+    Compute the KV world size and rank, collapsing TP ranks for MLA models.
+    """
+    use_mla = mla_enabled(vllm_config.model_config)
+    if not use_mla:
+        return world_size, rank
+    else:
+        # Tensor parallelism does not shard the KV caches of MLA models,
+        # so we need to "exclude" the effect of TP on the rank and world size.
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        # vLLM constructs TP groups first, and then constructs other
+        # parallel groups on top of the TP groups.
+        # For example, with TP=4 and PP=2:
+        # TP groups: [0, 1, 2, 3], [4, 5, 6, 7]
+        # PP groups: [0, 4], [1, 5], [2, 6], [3, 7]
+        # So we can "exclude" the effect of TP via rank // tp_size.
+        return world_size // tp_size, rank // tp_size
+
+
 def create_scheduler_adapter(
     server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
 ) -> LMCacheMPSchedulerAdapter:
-    # TODO: have a helper function to calculate the correct rank and
-    # world size for the MLA and other models
+    world_size, kv_rank = extract_world_size_and_kv_rank(
+        vllm_config.parallel_config.world_size,
+        vllm_config.parallel_config.rank,
+        vllm_config,
+    )
     return LMCacheMPSchedulerAdapter(
         server_url,
         zmq_context,
         vllm_config.model_config.model,
-        vllm_config.parallel_config.world_size,
-        vllm_config.parallel_config.rank,
+        world_size,
+        kv_rank,
         vllm_config.cache_config.block_size,
     )
 
@@ -78,14 +106,17 @@
 def create_worker_adapter(
     server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
 ) -> LMCacheMPWorkerAdapter:
-    # TODO: have a helper function to calculate the correct rank and
-    # world size for the MLA and other models
+    world_size, kv_rank = extract_world_size_and_kv_rank(
+        vllm_config.parallel_config.world_size,
+        vllm_config.parallel_config.rank,
+        vllm_config,
+    )
    return LMCacheMPWorkerAdapter(
         server_url,
         zmq_context,
         vllm_config.model_config.model,
-        vllm_config.parallel_config.world_size,
-        vllm_config.parallel_config.rank,
+        world_size,
+        kv_rank,
         vllm_config.cache_config.block_size,
     )
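
For illustration only (not part of the patch): a minimal, standalone sketch of the rank collapsing that `extract_world_size_and_kv_rank` performs. The `collapse_tp_rank` helper below is hypothetical; it drops the `VllmConfig`/`mla_enabled` dependencies and takes `use_mla` explicitly so the mapping can be checked in isolation, assuming the TP-first group layout described in the comment above.

```python
# Standalone sketch of the MLA rank collapsing, assuming vLLM's TP-first
# group layout (e.g. TP=4, PP=2 -> TP groups [0, 1, 2, 3] and [4, 5, 6, 7]).
# `collapse_tp_rank` is a hypothetical stand-in for
# extract_world_size_and_kv_rank, with use_mla passed in directly instead
# of being derived from a VllmConfig.


def collapse_tp_rank(
    world_size: int, rank: int, tp_size: int, use_mla: bool
) -> tuple[int, int]:
    if not use_mla:
        return world_size, rank
    # All TP ranks in a group hold the same (unsharded) MLA KV cache,
    # so they collapse onto a single KV rank.
    return world_size // tp_size, rank // tp_size


if __name__ == "__main__":
    tp_size, pp_size = 4, 2
    world_size = tp_size * pp_size
    for rank in range(world_size):
        kv_world_size, kv_rank = collapse_tp_rank(
            world_size, rank, tp_size, use_mla=True
        )
        print(f"rank {rank} -> kv_rank {kv_rank} (kv world size {kv_world_size})")
    # Expected: ranks 0-3 -> kv_rank 0, ranks 4-7 -> kv_rank 1,
    # i.e. one KV rank per TP group, as in the TP=4, PP=2 example above.
```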