[DeepSeek + LMCache Multiprocess] handle MLA for deepseek model + LMCache Multiprocess connector (#29039)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
commit 05c2dee7e9 (parent 1d642872a2)
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast

import torch
import zmq
from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger

from vllm.config import VllmConfig
@@ -60,17 +61,44 @@ def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
    return block_ids[0]


def extract_world_size_and_kv_rank(
    world_size: int,
    rank: int,
    vllm_config: VllmConfig,
) -> tuple[int, int]:
    """
    Convert the world size and rank into their KV-cache equivalents,
    collapsing tensor-parallel ranks when the model uses MLA.
    """
    use_mla = mla_enabled(vllm_config.model_config)
    if not use_mla:
        return world_size, rank
    else:
        # Tensor parallelism does not change the KV caches of MLA models,
        # so we need to "exclude" the effect of TP on rank and world size.
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        # vLLM constructs TP groups first, and then constructs the other
        # parallel groups on top of the TP groups.
        # For example, with TP=4, PP=2:
        #   TP groups: [0, 1, 2, 3], [4, 5, 6, 7]
        #   PP groups: [0, 4], [1, 5], [2, 6], [3, 7]
        # So we can "exclude" the effect of TP via rank // tp_size.
        return world_size // tp_size, rank // tp_size
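To make the collapse concrete, here is a standalone sketch (an illustration, not part of the commit) that replays the arithmetic for the TP=4, PP=2 layout described in the comments above:

# Reimplements the MLA branch of extract_world_size_and_kv_rank directly,
# with no vLLM or LMCache imports, purely for illustration.
def collapse_for_mla(world_size: int, rank: int, tp_size: int) -> tuple[int, int]:
    # Every rank in a TP group holds an identical MLA KV cache, so the
    # whole group maps to a single KV rank.
    return world_size // tp_size, rank // tp_size

# TP=4, PP=2 gives world_size=8: ranks 0-3 (the first TP group) all map
# to KV rank 0, and ranks 4-7 (the second TP group) to KV rank 1.
for rank in range(8):
    print(rank, "->", collapse_for_mla(8, rank, tp_size=4))
# ranks 0..3 -> (2, 0), ranks 4..7 -> (2, 1)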
def create_scheduler_adapter(
    server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPSchedulerAdapter:
    # TODO: have a helper function to calculate the correct rank and
    # world size for the MLA and other models
    world_size, kv_rank = extract_world_size_and_kv_rank(
        vllm_config.parallel_config.world_size,
        vllm_config.parallel_config.rank,
        vllm_config,
    )
    return LMCacheMPSchedulerAdapter(
        server_url,
        zmq_context,
        vllm_config.model_config.model,
        vllm_config.parallel_config.world_size,
        vllm_config.parallel_config.rank,
        world_size,
        kv_rank,
        vllm_config.cache_config.block_size,
    )

@@ -78,14 +106,17 @@ def create_scheduler_adapter(
def create_worker_adapter(
    server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPWorkerAdapter:
    # TODO: have a helper function to calculate the correct rank and
    # world size for the MLA and other models
    world_size, kv_rank = extract_world_size_and_kv_rank(
        vllm_config.parallel_config.world_size,
        vllm_config.parallel_config.rank,
        vllm_config,
    )
    return LMCacheMPWorkerAdapter(
        server_url,
        zmq_context,
        vllm_config.model_config.model,
        vllm_config.parallel_config.world_size,
        vllm_config.parallel_config.rank,
        world_size,
        kv_rank,
        vllm_config.cache_config.block_size,
    )
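For orientation, here is a hedged usage sketch of how the two factories fit together. The IPC endpoint string and the single-process wiring are illustrative assumptions (in vLLM the scheduler-side and worker-side adapters live in separate processes), not something this commit establishes:

import zmq

from vllm.config import VllmConfig

def build_adapters(vllm_config: VllmConfig, server_url: str = "ipc:///tmp/lmcache_mp"):
    # zmq.Context.instance() returns the process-wide singleton context.
    ctx = zmq.Context.instance()
    # Scheduler side: decides which blocks LMCache can serve.
    scheduler = create_scheduler_adapter(server_url, ctx, vllm_config)
    # Worker side: actually moves KV-cache data to and from LMCache.
    worker = create_worker_adapter(server_url, ctx, vllm_config)
    # Both receive the same collapsed (world_size, kv_rank), so for an MLA
    # model every TP rank addresses one shared copy of each KV block.
    return scheduler, worker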