[V1] Support Deepseek MTP (#18435)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Co-authored-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
Jiayi Yao 2025-05-23 12:26:28 -05:00 committed by GitHub
parent 371f7e4ca2
commit 2628a69e35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 120 additions and 66 deletions

View File

@ -2255,7 +2255,7 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
"draft_model"]
"draft_model", "deepseek_mtp"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]
@ -2519,6 +2519,15 @@ class SpeculativeConfig:
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type ==
"deepseek_mtp"):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
@ -2738,7 +2747,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3")
return self.method in ("eagle", "eagle3", "deepseek_mtp")
def __repr__(self) -> str:
method = self.method

View File

@ -1338,7 +1338,7 @@ class EngineArgs:
is_ngram_enabled = True
elif speculative_method == "medusa":
is_medusa_enabled = True
elif speculative_method in ("eagle", "eagle3"):
elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")

View File

@ -19,6 +19,7 @@ from vllm.sequence import IntermediateTensors
from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name)
from .interfaces import SupportsPP
from .utils import maybe_prefix
@ -145,7 +146,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
return logits
class DeepSeekMTP(nn.Module):
class DeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -10,9 +10,10 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
FlashAttentionMetadata)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
logger = init_logger(__name__)
@ -25,12 +26,15 @@ class EagleProposer:
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
self.vllm_config = vllm_config
self.speculative_config = vllm_config.speculative_config
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
self.runner = runner
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
@ -106,24 +110,46 @@ class EagleProposer:
# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=target_slot_mapping,
# TODO(woosuk): Support cascade attention.
use_cascade=False,
common_prefix_len=0,
cu_prefix_query_lens=None,
prefix_kv_lens=None,
suffix_kv_lens=None,
)
if self.method in ["eagle", "eagle3"]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] -
cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=target_slot_mapping,
# TODO(woosuk): Support cascade attention.
use_cascade=False,
common_prefix_len=0,
cu_prefix_query_lens=None,
prefix_kv_lens=None,
suffix_kv_lens=None,
)
elif self.method == "deepseek_mtp":
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens, seq_lens=seq_lens)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builder.build(
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise ValueError(f"Unsupported method: {self.method}")
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@ -136,11 +162,15 @@ class EagleProposer:
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
ret_hidden_states = self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens],
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
@ -150,6 +180,10 @@ class EagleProposer:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
@ -215,9 +249,9 @@ class EagleProposer:
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
self.input_ids[:input_batch_size],
self.positions[:input_batch_size],
self.hidden_states[:input_batch_size],
)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
@ -268,7 +302,7 @@ class EagleProposer:
batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024
prepare_input_kernel[(batch_size, )](
prepare_eagle_input_kernel[(batch_size, )](
token_indices,
cu_target_query_lens,
cu_num_tokens,
@ -320,9 +354,9 @@ class EagleProposer:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)
@ -367,29 +401,3 @@ def compute_probs_and_sample_next_token(
next_token_ids,
)
return next_token_ids, probs
@triton.jit
def prepare_input_kernel(
out_ptr,
cu_query_lens_ptr,
cu_num_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
# [start_pos, end_pos)
start_pos = tl.load(cu_num_tokens_ptr + pid)
end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
num_tokens = end_pos - start_pos
index_start = tl.load(cu_query_lens_ptr + pid)
num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
for i in tl.range(num_blocks):
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
out_ptr + start_pos + offset,
index_start + offset,
mask=offset < num_tokens,
)

View File

@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu_input_batch import InputBatch
@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
return False
return True
@triton.jit
def prepare_eagle_input_kernel(
out_ptr,
cu_query_lens_ptr,
cu_num_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
# [start_pos, end_pos)
start_pos = tl.load(cu_num_tokens_ptr + pid)
end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
num_tokens = end_pos - start_pos
index_start = tl.load(cu_query_lens_ptr + pid)
num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
for i in tl.range(num_blocks):
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
out_ptr + start_pos + offset,
index_start + offset,
mask=offset < num_tokens,
)

View File

@ -151,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_aux_hidden_state_outputs = False
if self.speculative_config:
self.use_spec_decode = True
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
if get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config,
self.device) # type: ignore
self.drafter = EagleProposer(self.vllm_config, self.device,
self) # type: ignore
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "medusa":
@ -1361,6 +1365,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if hasattr(eagle_attn_metadata, "block_table"):
block_table = eagle_attn_metadata.block_table
else:
block_table = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
@ -1406,7 +1416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_table,
block_table=block_table,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
@ -1723,8 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
hidden_states = outputs
if self.use_spec_decode and \
self.speculative_config.method in ('eagle', 'eagle3'):
if self.use_spec_decode and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens)