diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index bc1b3e2319d05..3ad9b49933275 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -133,7 +133,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): """Determines if draft_model_runner GPU multi-step can be used. Currently required conditions are: - 1. Only decodes + 1. Only decodes 2. Only flash-attn 3. No LORA 4. No prompt_adapter_config @@ -171,12 +171,12 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): num_steps: int = 1, **kwargs, ) -> Optional[List[SamplerOutput]]: - """Executes num_steps forward passes with advacement of input tensors + """Executes num_steps forward passes with advacement of input tensors on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. Optimizations used: 1. Input tensors are updated on the GPU directly - 2. Skips GPU=>CPU serialization of sampler outputs (we don't need + 2. Skips GPU=>CPU serialization of sampler outputs (we don't need them since we do batch expansion later that uses GPU outputs) 3. Reuses sampling tensors (since we run only decodes and they have a repeating sampling logic) @@ -302,7 +302,12 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): outputs.append(output) if self.return_hidden_states and is_fallback: - output.hidden_states = hidden_states + if use_cuda_graph: + indices = model_input.sampling_metadata\ + .selected_token_indices + output.hidden_states = hidden_states[:len(indices)] + else: + output.hidden_states = hidden_states if model_input.attn_metadata.num_prefills == 0 \ and self.indices_of_seq_with_bonus_tokens is not None: