[Model] MTP fallback to eager for DeepSeek v32 (#25982)

Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
Lucia Fang 2025-09-30 18:53:22 -07:00 committed by simon-mo
parent c214d699fd
commit bab9231bf1
5 changed files with 32 additions and 5 deletions

View File

@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
"target_attn_1": mock.MagicMock(),
"target_attn_2": mock.MagicMock()
}
target_indx_layers: dict[str, mock.MagicMock] = {}
# Draft model has one extra attention layer compared to target model
all_attn_layers = {
**target_attn_layers, "draft_extra_attn": mock.MagicMock()
}
all_indx_layers: dict[str, mock.MagicMock] = {}
# Make mock_get_layers return different values for each call
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
mock_get_layers.side_effect = [
target_attn_layers, target_indx_layers, all_attn_layers,
all_indx_layers
]
# Setup mock for pp group to return the appropriate value for world size
mock_pp_group = mock.MagicMock()
@ -658,6 +664,9 @@ def test_propose_tree(spec_token_tree):
# Mock runner for attention metadata building.
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(

View File

@ -63,7 +63,13 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
target_attn_layers = {"target_attn_1": mock.MagicMock()}
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
target_indexer_layers: dict = {}
all_indexer_layers: dict = {}
mock_get_layers.side_effect = [
target_attn_layers, target_indexer_layers, all_attn_layers,
all_indexer_layers
]
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1

View File

@ -41,7 +41,8 @@ MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding."""
enforce_eager: Optional[bool] = None
"""Override the default enforce_eager from model_config"""
# General speculative decoding control
num_speculative_tokens: SkipValidation[int] = None # type: ignore
"""The number of speculative tokens, if provided. It will default to the
@ -219,6 +220,11 @@ class SpeculativeConfig:
assert (
self.target_model_config
is not None), "target_model_config must be present for mtp"
if self.target_model_config.hf_text_config.model_type \
== "deepseek_v32":
# FIXME(luccafong): cudgraph with v32 MTP is not supported,
# remove this when the issue is fixed.
self.enforce_eager = True
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as

View File

@ -171,7 +171,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: int = 1

View File

@ -49,6 +49,7 @@ class EagleProposer:
):
self.vllm_config = vllm_config
self.speculative_config = vllm_config.speculative_config
assert self.speculative_config is not None
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
@ -71,10 +72,15 @@ class EagleProposer:
.is_multimodal_model
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
self.draft_indexer_metadata_builder: Optional[
AttentionMetadataBuilder] = None
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
not self.vllm_config.model_config.enforce_eager
and not self.speculative_config.enforce_eager)
self.cudagraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))