From d811b442d305b33b3aca2836c5d7f761effe76de Mon Sep 17 00:00:00 2001
From: Haco <75477391+xiaohajiayou@users.noreply.github.com>
Date: Sat, 1 Nov 2025 22:52:43 +0800
Subject: [PATCH] [Bugfix] DeepSeek V3.2 MTP metadata & CUDA graph issues (#26779)

Signed-off-by: xiaohajiayou <923390377@qq.com>
---
 vllm/v1/spec_decode/eagle.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 35c2e73e8ee2c..1e18eea2330a4 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -109,6 +109,7 @@ class EagleProposer:
             else []
         )
+        self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes)
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=device
         )
@@ -939,7 +940,7 @@ class EagleProposer:
                 self.vllm_config, DeepseekV32IndexerCache
             )
             draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
-            self.attn_layer_names = list(draft_attn_layer_names)
+            self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
             self.indexer_layer_names = list(draft_indexer_layer_names)
 
             if self.indexer_layer_names:
@@ -1050,16 +1051,18 @@ class EagleProposer:
         num_tokens: int,
         use_cudagraphs=True,
     ) -> None:
-        if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]:
+        # Determine if CUDA graphs should be used for this run.
+        cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
+        if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
 
         with set_forward_context(
             None,
             self.vllm_config,
             num_tokens=num_tokens,
-            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
-            if use_cudagraphs
-            else CUDAGraphMode.NONE,
+            cudagraph_runtime_mode=(
+                CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
+            ),
         ):
             if self.supports_mm_inputs:
                 input_ids = None
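
Note on the CUDA graph hunks: the new instance-level check in __init__ and the
per-call cudagraphs_enabled flag together ensure padding to a captured graph
size is only attempted when capture sizes actually exist, which the pre-patch
condition could trip over (indexing cudagraph_batch_sizes[-1] on an empty list).
The second hunk separately keeps DeepSeek V3.2 indexer layer names out of
attn_layer_names via a set difference so they live only in indexer_layer_names.
Below is a minimal, self-contained sketch of the gating pattern, under stated
assumptions: DummyProposer, resolve_run, and the pad_for_cudagraph stand-in are
hypothetical names for illustration, not vLLM's EagleProposer API.

    # Minimal sketch of the CUDA-graph gating pattern (not vLLM code).
    # pad_for_cudagraph here is a stand-in for vllm_config.pad_for_cudagraph.
    import bisect
    from dataclasses import dataclass, field


    @dataclass
    class DummyProposer:
        use_cuda_graph: bool = True
        # Capture sizes sorted ascending; empty when no graphs were captured.
        cudagraph_batch_sizes: list[int] = field(default_factory=list)

        def __post_init__(self) -> None:
            # The fix: keep CUDA graphs enabled only if capture sizes exist, so
            # later code never indexes cudagraph_batch_sizes[-1] on an empty list.
            self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes)

        def pad_for_cudagraph(self, num_tokens: int) -> int:
            # Round up to the smallest captured size that fits.
            idx = bisect.bisect_left(self.cudagraph_batch_sizes, num_tokens)
            return self.cudagraph_batch_sizes[idx]

        def resolve_run(self, num_tokens: int, use_cudagraphs: bool = True) -> tuple[int, bool]:
            # Per-call request ANDed with the instance-level capability flag.
            cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
            if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
                num_tokens = self.pad_for_cudagraph(num_tokens)
            return num_tokens, cudagraphs_enabled


    # No captured sizes: graphs disabled, no out-of-range indexing.
    print(DummyProposer(cudagraph_batch_sizes=[]).resolve_run(5))            # (5, False)
    # Captured sizes present: a 5-token batch is padded up to the nearest size.
    print(DummyProposer(cudagraph_batch_sizes=[1, 2, 4, 8]).resolve_run(5))  # (8, True)

Under these assumptions, an instance built with an empty capture list reports
CUDA graphs as disabled for every call, while one built with [1, 2, 4, 8] pads
a 5-token batch up to 8, mirroring how the patched path chooses between
CUDAGraphMode.PIECEWISE and CUDAGraphMode.NONE.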