[DeepSeek + LMCache Multiprocess] Handle MLA for DeepSeek models with the LMCache Multiprocess connector (#29039)

Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Kuntai Du 2025-11-20 09:40:49 +08:00 committed by GitHub
parent 1d642872a2
commit 05c2dee7e9

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
 import torch
 import zmq
+from lmcache.integration.vllm.utils import mla_enabled
 from lmcache.utils import init_logger as lmcache_init_logger
 from vllm.config import VllmConfig
@@ -60,17 +61,44 @@ def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
     return block_ids[0]
 
 
+def extract_world_size_and_kv_rank(
+    world_size: int,
+    rank: int,
+    vllm_config: VllmConfig,
+) -> tuple[int, int]:
+    """
+    Convert the world size and rank for MLA models.
+    """
+    use_mla = mla_enabled(vllm_config.model_config)
+    if not use_mla:
+        return world_size, rank
+    else:
+        # Tensor parallelism does not shard the KV caches of MLA models,
+        # so we need to "exclude" the effect of TP on rank and world size.
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        # vLLM constructs TP groups first, and then constructs the other
+        # parallel groups on top of the TP groups.
+        # For example, with TP=4, PP=2:
+        #   TP groups: [0, 1, 2, 3], [4, 5, 6, 7]
+        #   PP groups: [0, 4], [1, 5], [2, 6], [3, 7]
+        # So we can "exclude" the effect of TP via rank // tp_size.
+        return world_size // tp_size, rank // tp_size
+
+
 def create_scheduler_adapter(
     server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
 ) -> LMCacheMPSchedulerAdapter:
-    # TODO: have a helper function to calculate the correct rank and
-    # world size for the MLA and other models
+    world_size, kv_rank = extract_world_size_and_kv_rank(
+        vllm_config.parallel_config.world_size,
+        vllm_config.parallel_config.rank,
+        vllm_config,
+    )
     return LMCacheMPSchedulerAdapter(
         server_url,
         zmq_context,
         vllm_config.model_config.model,
-        vllm_config.parallel_config.world_size,
-        vllm_config.parallel_config.rank,
+        world_size,
+        kv_rank,
         vllm_config.cache_config.block_size,
     )
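
For illustration, the following standalone sketch (not part of the diff) enumerates the remapping that extract_world_size_and_kv_rank performs for the TP=4, PP=2 layout described in the comment above; the local remap() helper and the hard-coded sizes are stand-ins for this example only.

# Standalone sketch: the MLA rank remapping for the TP=4, PP=2 layout
# described in the comment above. remap() is a local stand-in that mirrors
# the helper's arithmetic; it is not the connector's actual function.
def remap(world_size: int, rank: int, tp_size: int, use_mla: bool) -> tuple[int, int]:
    if not use_mla:
        return world_size, rank
    return world_size // tp_size, rank // tp_size

tp_size, pp_size = 4, 2
world_size = tp_size * pp_size  # 8 ranks total
for rank in range(world_size):
    kv_world_size, kv_rank = remap(world_size, rank, tp_size, use_mla=True)
    print(f"rank {rank} -> kv_rank {kv_rank} / kv world size {kv_world_size}")
# Ranks 0-3 (the TP group of PP stage 0) all collapse to kv_rank 0, and
# ranks 4-7 (the TP group of PP stage 1) collapse to kv_rank 1, so every TP
# peer within a stage reports the same kv_rank.
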
@@ -78,14 +106,17 @@ def create_scheduler_adapter(
 def create_worker_adapter(
     server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
 ) -> LMCacheMPWorkerAdapter:
-    # TODO: have a helper function to calculate the correct rank and
-    # world size for the MLA and other models
+    world_size, kv_rank = extract_world_size_and_kv_rank(
+        vllm_config.parallel_config.world_size,
+        vllm_config.parallel_config.rank,
+        vllm_config,
+    )
     return LMCacheMPWorkerAdapter(
         server_url,
         zmq_context,
         vllm_config.model_config.model,
-        vllm_config.parallel_config.world_size,
-        vllm_config.parallel_config.rank,
+        world_size,
+        kv_rank,
         vllm_config.cache_config.block_size,
     )
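
As a rough sketch of how the two code paths could be exercised in isolation (for example in a unit test), the snippet below swaps the real VllmConfig and mla_enabled() for SimpleNamespace stand-ins with a local use_mla flag; it mirrors the helper's logic rather than importing it.

# Hypothetical test sketch: SimpleNamespace stand-ins replace VllmConfig and
# mla_enabled(); the local function mirrors extract_world_size_and_kv_rank.
from types import SimpleNamespace

def extract_sketch(world_size: int, rank: int, cfg) -> tuple[int, int]:
    if not cfg.use_mla:  # stand-in for mla_enabled(vllm_config.model_config)
        return world_size, rank
    tp_size = cfg.parallel_config.tensor_parallel_size
    return world_size // tp_size, rank // tp_size

mla_cfg = SimpleNamespace(
    use_mla=True, parallel_config=SimpleNamespace(tensor_parallel_size=4)
)
dense_cfg = SimpleNamespace(
    use_mla=False, parallel_config=SimpleNamespace(tensor_parallel_size=4)
)

# TP=4, PP=2 (8 ranks), looking at global rank 5:
assert extract_sketch(8, 5, mla_cfg) == (2, 1)    # MLA: TP collapsed away
assert extract_sketch(8, 5, dense_cfg) == (8, 5)  # non-MLA: unchanged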