[V1] Support Deepseek MTP (#18435)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Co-authored-by: Rui Qiao <ruisearch42@gmail.com>
parent 371f7e4ca2
commit 2628a69e35
@@ -2255,7 +2255,7 @@ class DeviceConfig:
 SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
-                            "draft_model"]
+                            "draft_model", "deepseek_mtp"]
 SpeculativeAcceptanceMethod = Literal["rejection_sampler",
                                       "typical_acceptance_sampler"]
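
With the new method name registered, an MTP-equipped checkpoint can be served
speculatively. A minimal usage sketch, assuming the `speculative_config`
keyword accepted by `LLM` at this revision and an illustrative model id:

    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V3",  # illustrative: any MTP-equipped checkpoint
        speculative_config={"num_speculative_tokens": 1},
    )
    # `method` resolves to "deepseek_mtp" automatically from the draft
    # model's `model_type` (see the next hunk).
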
@@ -2519,6 +2519,15 @@ class SpeculativeConfig:
             elif (self.draft_model_config.hf_config.model_type ==
                   "mlp_speculator"):
                 self.method = "mlp_speculator"
+            elif (self.draft_model_config.hf_config.model_type ==
+                  "deepseek_mtp"):
+                self.method = "deepseek_mtp"
+                if self.num_speculative_tokens > 1:
+                    logger.warning(
+                        "All Deepseek MTP models only have " \
+                        "one layer. Might need some code changes " \
+                        "to support multiple layers."
+                    )
             else:
                 self.method = "draft_model"
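
Detection keys off the draft model's Hugging Face config. For reference, a
hypothetical config excerpt (the `num_nextn_predict_layers` field name is an
assumption based on DeepSeek-V3 checkpoints; only `model_type` is inspected
by this branch):

    # Hypothetical MTP draft-module config.json, shown as a Python dict:
    hf_config = {
        "model_type": "deepseek_mtp",   # the field the elif matches on
        "num_nextn_predict_layers": 1,  # assumed single MTP layer
    }
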
@@ -2738,7 +2747,7 @@ class SpeculativeConfig:
         return self.num_speculative_tokens

     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3")
+        return self.method in ("eagle", "eagle3", "deepseek_mtp")

     def __repr__(self) -> str:
         method = self.method
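
Widening `use_eagle()` is the pivotal routing decision: MTP drafts, like
EAGLE drafts, condition on the target model's hidden states, so every EAGLE
call site picks up MTP for free. A self-contained illustration of the
predicate:

    def use_eagle(method: str) -> bool:
        # Mirrors SpeculativeConfig.use_eagle after this change.
        return method in ("eagle", "eagle3", "deepseek_mtp")

    assert use_eagle("deepseek_mtp")   # MTP rides the EAGLE path
    assert not use_eagle("medusa")     # medusa keeps its own proposer
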
@@ -1338,7 +1338,7 @@ class EngineArgs:
                     is_ngram_enabled = True
                 elif speculative_method == "medusa":
                     is_medusa_enabled = True
-                elif speculative_method in ("eagle", "eagle3"):
+                elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
                     is_eagle_enabled = True
             else:
                 speculative_model = self.speculative_config.get("model")
@@ -19,6 +19,7 @@ from vllm.sequence import IntermediateTensors

 from .deepseek_v2 import (DeepseekV2DecoderLayer,
                           get_spec_layer_idx_from_weight_name)
+from .interfaces import SupportsPP
 from .utils import maybe_prefix
@@ -145,7 +146,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         return logits


-class DeepSeekMTP(nn.Module):
+class DeepSeekMTP(nn.Module, SupportsPP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
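
Mixing in `SupportsPP` marks the MTP model as pipeline-parallel capable so
the loader accepts it when PP is enabled. A quick check, assuming the
`supports_pp` helper in vllm.model_executor.models.interfaces (not shown in
this diff):

    from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
    from vllm.model_executor.models.interfaces import supports_pp

    # Class-level protocol check; no weights needed.
    assert supports_pp(DeepSeekMTP)
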
@@ -10,9 +10,10 @@ from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
-from vllm.triton_utils import tl, triton
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
+                                                   FlashAttentionMetadata)
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel

 logger = init_logger(__name__)
@@ -25,12 +26,15 @@ class EagleProposer:
         self,
         vllm_config: VllmConfig,
         device: torch.device,
+        runner=None,
     ):
         self.vllm_config = vllm_config
         self.speculative_config = vllm_config.speculative_config
         self.draft_model_config = self.speculative_config.draft_model_config
+        self.method = self.speculative_config.method

+        self.runner = runner
+
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
@@ -106,24 +110,46 @@ class EagleProposer:
         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int()

-        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        max_seq_len = seq_lens.max().item()
-        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
-        attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_tokens,
-            max_query_len=max_num_tokens,
-            query_start_loc=cu_num_tokens,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table=block_table,
-            slot_mapping=target_slot_mapping,
-            # TODO(woosuk): Support cascade attention.
-            use_cascade=False,
-            common_prefix_len=0,
-            cu_prefix_query_lens=None,
-            prefix_kv_lens=None,
-            suffix_kv_lens=None,
-        )
+        if self.method in ["eagle", "eagle3"]:
+            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
+            max_seq_len = seq_lens.max().item()
+            max_num_tokens = (cu_num_tokens[1:] -
+                              cu_num_tokens[:-1]).max().item()
+            attn_metadata = FlashAttentionMetadata(
+                num_actual_tokens=num_tokens,
+                max_query_len=max_num_tokens,
+                query_start_loc=cu_num_tokens,
+                max_seq_len=max_seq_len,
+                seq_lens=seq_lens,
+                block_table=block_table,
+                slot_mapping=target_slot_mapping,
+                # TODO(woosuk): Support cascade attention.
+                use_cascade=False,
+                common_prefix_len=0,
+                cu_prefix_query_lens=None,
+                prefix_kv_lens=None,
+                suffix_kv_lens=None,
+            )
+        elif self.method == "deepseek_mtp":
+            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+            max_query_len = query_lens.max().item()
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=cu_num_tokens, seq_lens=seq_lens)
+
+            assert self.runner is not None
+
+            # FIXME: need to consider multiple kv_cache_groups
+            attn_metadata = self.runner.attn_metadata_builder.build(
+                num_reqs=batch_size,
+                num_actual_tokens=num_tokens,
+                max_query_len=max_query_len,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise ValueError(f"Unsupported method: {self.method}")

         if self.use_cuda_graph and \
             num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
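
Both branches ultimately need the maximum per-request query length, but the
MTP branch derives it from the cumulative token layout and hands metadata
construction to the runner's MLA builder instead of hand-building
FlashAttentionMetadata. A runnable toy of the length arithmetic (toy values,
CPU is fine):

    import torch

    # cu_num_tokens is a cumulative sum over per-request token counts.
    cu_num_tokens = torch.tensor([0, 3, 7, 8])            # 3 requests
    query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]   # tensor([3, 4, 1])
    max_query_len = query_lens.max().item()               # 4; .item() syncs host
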
@@ -136,11 +162,15 @@ class EagleProposer:
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            last_hidden_states, hidden_states = self.model(
-                input_ids=self.input_ids[:num_input_tokens],
-                positions=self.positions[:num_input_tokens],
-                hidden_states=self.hidden_states[:num_input_tokens],
+            ret_hidden_states = self.model(
+                self.input_ids[:num_input_tokens],
+                self.positions[:num_input_tokens],
+                self.hidden_states[:num_input_tokens],
             )
+            if self.method == "deepseek_mtp":
+                last_hidden_states = ret_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)
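
Two things change here: the call switches from keyword to positional
arguments (presumably because the EAGLE and MTP forward signatures name the
hidden-state input differently; an assumption, since the signatures are not
shown in this diff), and the return value is unpacked per method because the
MTP forward returns a single tensor rather than a (last_hidden_states,
hidden_states) pair. A runnable toy of the unpacking:

    import torch

    def unpack(ret, method):
        # DeepSeek MTP's forward returns one hidden-state tensor, while
        # the EAGLE variants return (last_hidden_states, hidden_states).
        if method == "deepseek_mtp":
            return ret, None
        last_hidden_states, hidden_states = ret
        return last_hidden_states, hidden_states

    h = torch.zeros(4, 8)
    assert unpack(h, "deepseek_mtp")[0] is h
    assert unpack((h, h), "eagle")[0] is h
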
@@ -150,6 +180,10 @@ class EagleProposer:
             # [batch_size, 1]
             return draft_token_ids.view(-1, 1)

+        # TODO: Currently, MTP module released by deepseek only has
+        # one layer. Adapt this code to support multiple layers once
+        # there's a multi-layer MTP module.
+
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
@@ -215,9 +249,9 @@ class EagleProposer:
                 self.vllm_config,
                 num_tokens=input_batch_size):
             last_hidden_states, hidden_states = self.model(
-                input_ids=self.input_ids[:input_batch_size],
-                positions=self.positions[:input_batch_size],
-                hidden_states=self.hidden_states[:input_batch_size],
+                self.input_ids[:input_batch_size],
+                self.positions[:input_batch_size],
+                self.hidden_states[:input_batch_size],
             )
         hidden_states = hidden_states[:batch_size]
         logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -268,7 +302,7 @@ class EagleProposer:

         batch_size = num_rejected_tokens.shape[0]
         BLOCK_SIZE = 1024
-        prepare_input_kernel[(batch_size, )](
+        prepare_eagle_input_kernel[(batch_size, )](
             token_indices,
             cu_target_query_lens,
             cu_num_tokens,
@@ -320,9 +354,9 @@ class EagleProposer:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
             self.model(
-                input_ids=self.input_ids[:num_tokens],
-                positions=self.positions[:num_tokens],
-                hidden_states=self.hidden_states[:num_tokens],
+                self.input_ids[:num_tokens],
+                self.positions[:num_tokens],
+                self.hidden_states[:num_tokens],
             )
@@ -367,29 +401,3 @@ def compute_probs_and_sample_next_token(
         next_token_ids,
     )
     return next_token_ids, probs
-
-
-@triton.jit
-def prepare_input_kernel(
-    out_ptr,
-    cu_query_lens_ptr,
-    cu_num_tokens_ptr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    # [start_pos, end_pos)
-    start_pos = tl.load(cu_num_tokens_ptr + pid)
-    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
-    num_tokens = end_pos - start_pos
-
-    index_start = tl.load(cu_query_lens_ptr + pid)
-
-    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
-    for i in tl.range(num_blocks):
-        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        tl.store(
-            out_ptr + start_pos + offset,
-            index_start + offset,
-            mask=offset < num_tokens,
-        )
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+from vllm.triton_utils import tl, triton
 from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
         return False

     return True
+
+
+@triton.jit
+def prepare_eagle_input_kernel(
+    out_ptr,
+    cu_query_lens_ptr,
+    cu_num_tokens_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    # [start_pos, end_pos)
+    start_pos = tl.load(cu_num_tokens_ptr + pid)
+    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
+    num_tokens = end_pos - start_pos
+
+    index_start = tl.load(cu_query_lens_ptr + pid)
+
+    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
+    for i in tl.range(num_blocks):
+        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        tl.store(
+            out_ptr + start_pos + offset,
+            index_start + offset,
+            mask=offset < num_tokens,
+        )
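
The kernel (moved here unchanged from eagle.py, apart from the rename) maps
each request's kept tokens back to their offsets in the target token stream.
A runnable toy driver (CUDA and triton required; values chosen so each
request drops one rejected token):

    import torch
    from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel

    # Target layout: request 0 owns tokens [0, 4), request 1 owns [4, 9).
    cu_query_lens = torch.tensor([0, 4, 9], device="cuda", dtype=torch.int32)
    # Draft layout after dropping one rejected token per request.
    cu_num_tokens = torch.tensor([0, 3, 7], device="cuda", dtype=torch.int32)
    out = torch.empty(7, device="cuda", dtype=torch.int32)

    prepare_eagle_input_kernel[(2, )](out, cu_query_lens, cu_num_tokens,
                                      BLOCK_SIZE=1024)
    # out == [0, 1, 2, 4, 5, 6, 7]: per request, the kept token indices,
    # re-based at that request's start offset in the target stream.
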
@@ -151,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_aux_hidden_state_outputs = False
         if self.speculative_config:
             self.use_spec_decode = True
+
+            # NOTE(Jiayi): currently we put the entire draft model on
+            # the last PP rank. This is not ideal if there are many
+            # layers in the draft model.
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
                 elif self.speculative_config.use_eagle():
-                    self.drafter = EagleProposer(self.vllm_config,
-                                                 self.device)  # type: ignore
+                    self.drafter = EagleProposer(self.vllm_config, self.device,
+                                                 self)  # type: ignore
                     if self.speculative_config.method == "eagle3":
                         self.use_aux_hidden_state_outputs = True
                 elif self.speculative_config.method == "medusa":
@@ -1361,6 +1365,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 device=self.device)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]

+            # NOTE: deepseek_mtp uses MLA which does not have `block_table`
+            if hasattr(eagle_attn_metadata, "block_table"):
+                block_table = eagle_attn_metadata.block_table
+            else:
+                block_table = None
+
             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
                 target_token_ids = self.input_ids[:num_scheduled_tokens]
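
MLA-style metadata simply lacks the attribute, so a default-None lookup
expresses the same fallback; a self-contained toy of the guard (stand-in
objects, not vLLM types):

    from types import SimpleNamespace

    mla_meta = SimpleNamespace()                  # stand-in for MLA metadata
    fa_meta = SimpleNamespace(block_table="bt")   # stand-in for FA metadata
    assert getattr(mla_meta, "block_table", None) is None
    assert getattr(fa_meta, "block_table", None) == "bt"
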
@@ -1406,7 +1416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 target_slot_mapping=target_slot_mapping,
                 next_token_ids=next_token_ids,
                 cu_num_tokens=cu_num_tokens,
-                block_table=eagle_attn_metadata.block_table,
+                block_table=block_table,
                 sampling_metadata=sampling_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()
@@ -1723,8 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             hidden_states = outputs

-        if self.use_spec_decode and \
-            self.speculative_config.method in ('eagle', 'eagle3'):
+        if self.use_spec_decode and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             self.drafter.dummy_run(num_tokens)