From 9804145cac8b108725b2a431be6cd8a65d20ef9a Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Thu, 27 Feb 2025 18:28:08 -0500
Subject: [PATCH] [Model][Speculative Decoding] Expand DeepSeek MTP code to support k > n_predict (#13626)

Signed-off-by: Benjamin Chislett
---
 vllm/config.py                             | 11 +++++------
 vllm/model_executor/models/deepseek_mtp.py | 14 +++++++++-----
 vllm/spec_decode/draft_model_runner.py     | 11 ++++-------
 vllm/spec_decode/multi_step_worker.py      | 17 +++++++++++++++++
 vllm/spec_decode/spec_decode_worker.py     |  6 +++---
 vllm/worker/model_runner.py                | 12 +++++++++++-
 6 files changed, 49 insertions(+), 22 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index cb683d19386b9..c3f9932ab8b3f 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1978,13 +1978,12 @@ class SpeculativeConfig:
             if num_speculative_tokens is None:
                 # Default to max value defined in draft model config.
                 num_speculative_tokens = n_predict
-            elif num_speculative_tokens > n_predict:
-                # Verify provided value doesn't exceed the maximum
-                # supported by the draft model.
+            elif num_speculative_tokens > n_predict and \
+                    num_speculative_tokens % n_predict != 0:
+                # Ensure divisibility for MTP module reuse.
                 raise ValueError(
-                    "This speculative model supports a maximum of "
-                    f"num_speculative_tokens={n_predict}, but "
-                    f"{num_speculative_tokens=} was provided.")
+                    f"{num_speculative_tokens=} must be divisible by "
+                    f"{n_predict=}")
 
         speculative_draft_tensor_parallel_size = \
             SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index cac1b2b3b11cc..e7fde76cd0ba4 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -87,7 +87,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
                                                  hidden_states=hidden_states,
                                                  residual=None)
         hidden_states = residual + hidden_states
-        return self.shared_head(hidden_states)
+        return hidden_states
 
 
 class DeepSeekMultiTokenPredictor(nn.Module):
@@ -121,12 +121,13 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
             positions,
             previous_hidden_states,
             inputs_embeds,
-            spec_step_idx,
+            current_step_idx,
         )
 
     def compute_logits(
@@ -135,9 +136,12 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        mtp_layer = self.layers[str(self.mtp_start_layer_idx +
+                                    current_step_idx)]
         logits = self.logits_processor(mtp_layer.shared_head.head,
-                                       hidden_states, sampling_metadata)
+                                       mtp_layer.shared_head(hidden_states),
+                                       sampling_metadata)
         return logits
 
 
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index c54e6abe18d73..bc1b3e2319d05 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -50,12 +50,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """
 
     def __init__(self, model_runner: ModelRunnerBase):
-        if hasattr(
-                model_runner,
-                "return_hidden_states") and model_runner.return_hidden_states:
-            raise ValueError(
-                "return_hidden_states is not supported for TP1DraftModelRunner."
-            )
         super().__init__(model_runner)
 
         self.indices_of_seq_with_bonus_tokens = None
@@ -153,7 +147,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             return False
 
         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
+        if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
             return False
 
         # TODO: Add support for LORA
@@ -307,6 +301,9 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             )
             outputs.append(output)
 
+            if self.return_hidden_states and is_fallback:
+                output.hidden_states = hidden_states
+
             if model_input.attn_metadata.num_prefills == 0 \
                 and self.indices_of_seq_with_bonus_tokens is not None:
                 assert output.sampled_token_ids is not None
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index c28d413efe747..d8d54918fa98e 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -96,12 +96,16 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
             # TODO: Remove this branch once DraftModelRunner supports TP>1
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
+            if expanded_request.previous_hidden_states is not None:
+                self.worker.model_runner.return_hidden_states = True
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
+                self._maybe_update_previous_hidden_states(
+                    model_output, expanded_request)
 
                 self._append_new_tokens(
                     model_output, expanded_request.seq_group_metadata_list,
@@ -115,6 +119,19 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
             model_outputs, indices_of_seq_with_bonus_tokens)
         return filtered_model_outputs, True
 
+    @staticmethod
+    def _maybe_update_previous_hidden_states(
+            model_output: SamplerOutput,
+            expanded_request: ExecuteModelRequest) -> None:
+        """
+        Updates the previous hidden states in an expanded request
+        in-place with the hidden states from the model output.
+        """
+        if expanded_request.previous_hidden_states is not None:
+            expanded_request.previous_hidden_states = HiddenStates(
+                model_output.hidden_states,
+                expanded_request.seq_group_metadata_list)
+
     @staticmethod
     def _expand_execute_model_request(
         execute_model_req: ExecuteModelRequest,
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 871a3aee63063..8909a41bc99fc 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -184,8 +184,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
         elif draft_model_config.hf_config.model_type == "medusa":
             proposer_worker = MedusaWorker(**draft_worker_kwargs)
         else:
-            if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
-                "deepseek_mtp":
+            if draft_tp == 1:
                 if current_platform.is_cuda_alike():
                     draft_worker_kwargs[
                         "model_runner_cls"] = TP1DraftModelRunner
@@ -203,7 +202,8 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
                 proposer_worker = MultiStepWorker(**draft_worker_kwargs)
 
             if draft_model_config.hf_config.model_type == "deepseek_mtp":
-                num_spec_prefill_steps = num_speculative_tokens
+                num_spec_prefill_steps = \
+                    draft_model_config.hf_config.n_predict
 
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index a37a3168bbbc7..bb2228165b528 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1685,11 +1685,22 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         # TODO(andoorve): We can remove this once all
         # virtual engines share the same kv cache.
         virtual_engine = model_input.virtual_engine
+        previous_hidden_states = kwargs.get("previous_hidden_states")
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
             model_executable = self.graph_runners[virtual_engine][
                 graph_batch_size]
+            if previous_hidden_states is not None:
+                previous_hidden_states = torch.cat([
+                    previous_hidden_states,
+                    torch.empty([
+                        graph_batch_size - previous_hidden_states.shape[0],
+                        *previous_hidden_states.shape[1:]
+                    ],
+                                dtype=previous_hidden_states.dtype,
+                                device=previous_hidden_states.device)
+                ])
         else:
             model_executable = self.model
 
@@ -1716,7 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_inner_state else {}
-        previous_hidden_states = kwargs.get("previous_hidden_states")
         model_kwargs = {}
         if previous_hidden_states is not None:
             model_kwargs["previous_hidden_states"] = previous_hidden_states
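
Note (not part of the patch): a minimal sketch of the cyclic MTP module reuse this change enables, assuming the modulo indexing added in deepseek_mtp.py above. The names num_mtp_layers and num_speculative_tokens stand in for the draft model's n_predict and the configured k; the values are illustrative only.

# Illustrative only: maps each speculative step onto an available MTP module,
# mirroring `spec_step_idx % self.num_mtp_layers` in the diff above.
num_mtp_layers = 1          # n_predict exposed by the draft model config
num_speculative_tokens = 4  # k; must be a multiple of n_predict (see config.py)

assert num_speculative_tokens % num_mtp_layers == 0
for spec_step_idx in range(num_speculative_tokens):
    current_step_idx = spec_step_idx % num_mtp_layers
    print(f"draft step {spec_step_idx} -> MTP module {current_step_idx}")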