[Frontend][Bugfix] support prefill decode disaggregation on deepseek (#14824)
Signed-off-by: billishyahao <bill.he@amd.com>
Co-authored-by: Zhai Feiyue <80079571+ZhaiFeiyue@users.noreply.github.com>
This commit is contained in: parent bfe2fe0af4, commit 742369d35a
Disaggregated prefill example script:

@@ -8,6 +8,9 @@ set -xe
 echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
 sleep 1
 
+# meta-llama/Meta-Llama-3.1-8B-Instruct or deepseek-ai/DeepSeek-V2-Lite
+MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}
+
 # Trap the SIGINT signal (triggered by Ctrl+C)
 trap 'cleanup' INT
 
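With MODEL_NAME now driven by the HF_MODEL_NAME environment variable, the example can be pointed at DeepSeek without editing the script. A minimal usage sketch (the script filename below is a placeholder, not taken from this diff):

# Hypothetical script name; substitute the actual example script path.
HF_MODEL_NAME=deepseek-ai/DeepSeek-V2-Lite bash disagg_prefill_example.sh
# Without the override, the default meta-llama/Meta-Llama-3.1-8B-Instruct is used.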
@@ -44,18 +47,20 @@ wait_for_server() {
 # You can also adjust --kv-ip and --kv-port for distributed inference.
 
 # prefilling instance, which is the KV producer
-CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
+CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
     --port 8100 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
+    --trust-remote-code \
     --kv-transfer-config \
     '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
 
 # decoding instance, which is the KV consumer
-CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
+CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
     --port 8200 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
+    --trust-remote-code \
     --kv-transfer-config \
     '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
 
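The two --kv-transfer-config payloads differ only in role and rank: prefill is the producer (rank 0), decode is the consumer (rank 1), and both agree on the size of the KV-transfer group. A small, self-contained check of that pairing (plain Python, not vLLM's own parsing code):

import json

# The exact JSON strings passed to --kv-transfer-config above.
producer = json.loads('{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'
                      '"kv_rank":0,"kv_parallel_size":2}')
consumer = json.loads('{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'
                      '"kv_rank":1,"kv_parallel_size":2}')

assert producer["kv_rank"] == 0 and consumer["kv_rank"] == 1
assert producer["kv_parallel_size"] == consumer["kv_parallel_size"] == 2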
@@ -78,7 +83,7 @@ sleep 1
 output1=$(curl -X POST -s http://localhost:8000/v1/completions \
 -H "Content-Type: application/json" \
 -d '{
-"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+"model": "'"$MODEL_NAME"'",
 "prompt": "San Francisco is a",
 "max_tokens": 10,
 "temperature": 0
@@ -87,7 +92,7 @@ output1=$(curl -X POST -s http://localhost:8000/v1/completions \
 output2=$(curl -X POST -s http://localhost:8000/v1/completions \
 -H "Content-Type: application/json" \
 -d '{
-"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+"model": "'"$MODEL_NAME"'",
 "prompt": "Santa Clara is a",
 "max_tokens": 10,
 "temperature": 0
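The "'"$MODEL_NAME"'" pattern in the curl payloads closes the single-quoted JSON, splices the shell variable in inside double quotes, then reopens the single quotes. A quick illustration of how it expands (illustrative model value):

MODEL_NAME=deepseek-ai/DeepSeek-V2-Lite
echo '{"model": "'"$MODEL_NAME"'"}'
# prints: {"model": "deepseek-ai/DeepSeek-V2-Lite"}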
SimpleConnector (KV transfer connector):

@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
@@ -37,6 +38,8 @@ class SimpleConnector(KVConnectorBase):
 
         self.config = config.kv_transfer_config
         self.tp_size = config.parallel_config.tensor_parallel_size
+        self.is_deepseek_mla = config.model_config.is_deepseek_mla
+        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
 
         if self.config.kv_connector == "PyNcclConnector":
             from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
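use_mla_opt simply inverts the VLLM_MLA_DISABLE knob exposed through vllm.envs. A standalone sketch of the same derivation (plain os.environ here instead of vllm.envs, and 0/1 parsing is an assumption):

import os

# Assumed 0/1 parsing; vllm.envs.VLLM_MLA_DISABLE wraps the same environment variable.
VLLM_MLA_DISABLE = os.environ.get("VLLM_MLA_DISABLE", "0") == "1"
use_mla_opt = not VLLM_MLA_DISABLE
print(use_mla_opt)  # True unless VLLM_MLA_DISABLE=1 is exported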
@@ -167,8 +170,26 @@ class SimpleConnector(KVConnectorBase):
         num_heads = int(model_config.num_key_value_heads / self.tp_size)
         hidden_size = model_config.hidden_size
         num_attention_heads = model_config.num_attention_heads
-        head_size = getattr(model_config, "head_dim",
-                            int(hidden_size // num_attention_heads))
+
+        # Deepseek's MLA (Multi-head Latent Attention) uses two different
+        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
+        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
+        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
+        # kv_lora_rank + qk_rope_head_dim].
+        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
+        # to a kv_cache shape of [2, num_blks, blk_size,
+        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
+        # For more details, see vllm/attention/backends/mla/common.py.
+        if self.is_deepseek_mla and self.use_mla_opt:
+            head_size = model_config.kv_lora_rank + \
+                model_config.qk_rope_head_dim
+            num_heads = 1
+        elif self.is_deepseek_mla and not self.use_mla_opt:
+            head_size = model_config.qk_nope_head_dim + \
+                model_config.qk_rope_head_dim
+        else:
+            head_size = getattr(model_config, "head_dim",
+                                int(hidden_size // num_attention_heads))
 
         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
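A worked instance of the head_size selection above, using DeepSeek-V2-Lite-style dimensions (values recalled from the public config and used only for illustration; check the model's config.json):

# Assumed DeepSeek-V2-Lite dimensions (illustrative, not read from this diff).
kv_lora_rank = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128

# VLLM_MLA_DISABLE=0 (default): the latent cache stores one "head" per token.
head_size_mla = kv_lora_rank + qk_rope_head_dim       # 576
num_heads_mla = 1

# VLLM_MLA_DISABLE=1: standard attention layout per KV head.
head_size_std = qk_nope_head_dim + qk_rope_head_dim   # 192

print(head_size_mla, num_heads_mla, head_size_std)    # 576 1 192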
@@ -192,8 +213,12 @@ class SimpleConnector(KVConnectorBase):
             for layer_id in range(start_layer, end_layer):
                 kv_cache = kv_caches[layer_id - start_layer]
 
-                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                if self.is_deepseek_mla and self.use_mla_opt:
+                    key_cache = kv_cache.reshape(-1, num_heads, head_size)
+                    value_cache = kv_cache.reshape(-1, num_heads, head_size)
+                else:
+                    key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+                    value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
 
                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
 
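On the MLA path the per-layer kv_cache is a single latent tensor, so key_cache and value_cache above are just two views of the same data; on the standard path, K and V live in separate planes. A small torch sketch of the two layouts described in the comment block (all sizes illustrative):

import torch

num_blks, blk_size = 4, 16

# MLA layout: [num_blks, blk_size, 1, kv_lora_rank + qk_rope_head_dim]
kv_cache_mla = torch.zeros(num_blks, blk_size, 1, 576)
key_cache = kv_cache_mla.reshape(-1, 1, 576)    # same storage,
value_cache = kv_cache_mla.reshape(-1, 1, 576)  # viewed twice

# Standard layout: [2, num_blks, blk_size, num_key_value_heads / tp, head_size]
kv_cache_std = torch.zeros(2, num_blks, blk_size, 8, 192)
k = kv_cache_std[0].reshape(-1, 8, 192)
v = kv_cache_std[1].reshape(-1, 8, 192)
print(key_cache.shape, k.shape, v.shape)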
@@ -223,6 +248,8 @@ class SimpleConnector(KVConnectorBase):
         # and hidden states.
         bypass_model_exec = True
 
+        model_config = model_executable.model.config
+
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
@@ -291,19 +318,35 @@ class SimpleConnector(KVConnectorBase):
                 kv_cache = kv_caches[i - model_executable.model.start_layer]
                 layer = model_executable.model.layers[i]
 
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                ops.reshape_and_cache_flash(
-                    keys[i - model_executable.model.start_layer].to(
-                        key_cache.device),
-                    values[i - model_executable.model.start_layer].to(
-                        value_cache.device),
-                    key_cache,
-                    value_cache,
-                    slot_mapping[start_pos:end_pos],
-                    layer.self_attn.attn.kv_cache_dtype,
-                    layer.self_attn.attn._k_scale,
-                    layer.self_attn.attn._v_scale,
-                )
+                if self.is_deepseek_mla and self.use_mla_opt:
+                    layer.self_attn.attn = layer.self_attn.mla_attn
+                    k_c_normed_k_pe = keys[
+                        i - model_executable.model.start_layer].to(
+                            kv_cache.device).squeeze(1)
+                    k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
+                    k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
+                    ops.concat_and_cache_mla(
+                        k_c_normed,
+                        k_pe,
+                        kv_cache,
+                        slot_mapping[start_pos:end_pos],
+                        layer.self_attn.attn.kv_cache_dtype,
+                        layer.self_attn.attn._k_scale,
+                    )
+                else:
+                    key_cache, value_cache = kv_cache[0], kv_cache[1]
+                    ops.reshape_and_cache_flash(
+                        keys[i - model_executable.model.start_layer].to(
+                            key_cache.device),
+                        values[i - model_executable.model.start_layer].to(
+                            value_cache.device),
+                        key_cache,
+                        value_cache,
+                        slot_mapping[start_pos:end_pos],
+                        layer.self_attn.attn.kv_cache_dtype,
+                        layer.self_attn.attn._k_scale,
+                        layer.self_attn.attn._v_scale,
+                    )
 
             hidden_or_intermediate_states_for_one_req.append(hidden)
 
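On the consumer side, each transferred MLA "key" row is the concatenation of the compressed latent KV and the rotary part; slicing at model_config.kv_lora_rank recovers the two pieces passed to concat_and_cache_mla. A standalone shape check (illustrative sizes, assumed DeepSeek-V2-Lite-style dimensions):

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64   # assumed values, for illustration only
num_tokens = 8

k_c_normed_k_pe = torch.randn(num_tokens, kv_lora_rank + qk_rope_head_dim)
k_c_normed = k_c_normed_k_pe[:, :kv_lora_rank]   # compressed latent KV
k_pe = k_c_normed_k_pe[:, kv_lora_rank:]         # rotary positional part

assert k_c_normed.shape == (num_tokens, kv_lora_rank)
assert k_pe.shape == (num_tokens, qk_rope_head_dim)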
DeepseekV2Model:

@@ -589,6 +589,7 @@ class DeepseekV2Model(nn.Module):
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        self.config = config
 
         self.vocab_size = config.vocab_size
 
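Storing self.config on DeepseekV2Model is what lets the connector read the MLA dimensions through model_executable.model.config, as done in the recv path above. A minimal stand-in illustrating that access pattern (hypothetical classes, not vLLM code):

class _Config:                       # stand-in for the HF model config
    kv_lora_rank = 512
    qk_rope_head_dim = 64

class _Model:                        # stand-in for DeepseekV2Model
    def __init__(self, config):
        self.config = config         # mirrors the added `self.config = config`

class _Executable:                   # stand-in for the wrapped model executable
    def __init__(self):
        self.model = _Model(_Config())

model_executable = _Executable()
model_config = model_executable.model.config
print(model_config.kv_lora_rank + model_config.qk_rope_head_dim)  # 576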