From 1e3e76b6cc75c90da6c738951b0b47835de0b6be Mon Sep 17 00:00:00 2001 From: pyc96 Date: Wed, 5 Mar 2025 14:22:40 -0800 Subject: [PATCH] [Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due to shape mismatch (#14237) Signed-off-by: pyc96 --- vllm/spec_decode/draft_model_runner.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index bc1b3e2319d05..3ad9b49933275 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -133,7 +133,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): """Determines if draft_model_runner GPU multi-step can be used. Currently required conditions are: - 1. Only decodes + 1. Only decodes 2. Only flash-attn 3. No LORA 4. No prompt_adapter_config @@ -171,12 +171,12 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): num_steps: int = 1, **kwargs, ) -> Optional[List[SamplerOutput]]: - """Executes num_steps forward passes with advacement of input tensors + """Executes num_steps forward passes with advacement of input tensors on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. Optimizations used: 1. Input tensors are updated on the GPU directly - 2. Skips GPU=>CPU serialization of sampler outputs (we don't need + 2. Skips GPU=>CPU serialization of sampler outputs (we don't need them since we do batch expansion later that uses GPU outputs) 3. Reuses sampling tensors (since we run only decodes and they have a repeating sampling logic) @@ -302,7 +302,12 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): outputs.append(output) if self.return_hidden_states and is_fallback: - output.hidden_states = hidden_states + if use_cuda_graph: + indices = model_input.sampling_metadata\ + .selected_token_indices + output.hidden_states = hidden_states[:len(indices)] + else: + output.hidden_states = hidden_states if model_input.attn_metadata.num_prefills == 0 \ and self.indices_of_seq_with_bonus_tokens is not None: