[Model][Speculative Decoding] Expand DeepSeek MTP code to support k > n_predict (#13626)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Benjamin Chislett 2025-02-27 18:28:08 -05:00 committed by GitHub
parent 2e94b9cfbb
commit 9804145cac
6 changed files with 49 additions and 22 deletions


@@ -1978,13 +1978,12 @@ class SpeculativeConfig:
         if num_speculative_tokens is None:
             # Default to max value defined in draft model config.
             num_speculative_tokens = n_predict
-        elif num_speculative_tokens > n_predict:
-            # Verify provided value doesn't exceed the maximum
-            # supported by the draft model.
+        elif num_speculative_tokens > n_predict and \
+                num_speculative_tokens % n_predict != 0:
+            # Ensure divisibility for MTP module reuse.
             raise ValueError(
-                "This speculative model supports a maximum of "
-                f"num_speculative_tokens={n_predict}, but "
-                f"{num_speculative_tokens=} was provided.")
+                f"{num_speculative_tokens=} must be divisible by "
+                f"{n_predict=}")

         speculative_draft_tensor_parallel_size = \
             SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
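
For context on the relaxed check above, here is a minimal, self-contained sketch (not the vLLM API; the function name is made up for illustration): values of num_speculative_tokens at or below n_predict, or any multiple of n_predict, are accepted, while other values above n_predict are rejected.

def check_num_speculative_tokens(num_speculative_tokens: int,
                                 n_predict: int) -> None:
    # Mirrors the divisibility rule above: k may exceed n_predict only if
    # the MTP modules can be reused a whole number of times.
    if num_speculative_tokens > n_predict and \
            num_speculative_tokens % n_predict != 0:
        raise ValueError(f"{num_speculative_tokens=} must be divisible by "
                         f"{n_predict=}")

check_num_speculative_tokens(4, 1)   # ok: k=4 with a single MTP module (reused 4x)
check_num_speculative_tokens(6, 2)   # ok: 6 is a multiple of 2
try:
    check_num_speculative_tokens(5, 2)
except ValueError as err:
    print(err)  # num_speculative_tokens=5 must be divisible by n_predict=2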


@@ -87,7 +87,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
             hidden_states=hidden_states,
             residual=None)
         hidden_states = residual + hidden_states
-        return self.shared_head(hidden_states)
+        return hidden_states


 class DeepSeekMultiTokenPredictor(nn.Module):
@@ -121,12 +121,13 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
             positions,
             previous_hidden_states,
             inputs_embeds,
-            spec_step_idx,
+            current_step_idx,
         )

     def compute_logits(
@@ -135,9 +136,12 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        mtp_layer = self.layers[str(self.mtp_start_layer_idx +
+                                    current_step_idx)]
         logits = self.logits_processor(mtp_layer.shared_head.head,
-                                       hidden_states, sampling_metadata)
+                                       mtp_layer.shared_head(hidden_states),
+                                       sampling_metadata)
         return logits
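
To see why the modulo makes k > n_predict work, here is a hedged toy sketch (a plain dict stands in for the real nn.ModuleDict, and the layer index 61 is only an example value) of how spec_step_idx wraps back onto the available MTP modules:

num_mtp_layers = 1        # DeepSeek MTP checkpoints typically ship n_predict = 1
mtp_start_layer_idx = 61  # example value; the real index comes from the model config
layers = {str(mtp_start_layer_idx + i): f"mtp_module_{i}"
          for i in range(num_mtp_layers)}

for spec_step_idx in range(4):  # k = 4 speculative steps
    current_step_idx = spec_step_idx % num_mtp_layers
    # Every step resolves to mtp_module_0 when only one MTP module exists.
    print(spec_step_idx, "->", layers[str(mtp_start_layer_idx + current_step_idx)])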


@@ -50,12 +50,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """

     def __init__(self, model_runner: ModelRunnerBase):
-        if hasattr(
-                model_runner,
-                "return_hidden_states") and model_runner.return_hidden_states:
-            raise ValueError(
-                "return_hidden_states is not supported for TP1DraftModelRunner."
-            )
         super().__init__(model_runner)

         self.indices_of_seq_with_bonus_tokens = None
@@ -153,7 +147,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             return False

         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
+        if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
             return False

         # TODO: Add support for LORA
@@ -307,6 +301,9 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
                 )
                 outputs.append(output)

+                if self.return_hidden_states and is_fallback:
+                    output.hidden_states = hidden_states
+
                 if model_input.attn_metadata.num_prefills == 0 \
                     and self.indices_of_seq_with_bonus_tokens is not None:
                     assert output.sampled_token_ids is not None


@@ -96,12 +96,16 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
             # TODO: Remove this branch once DraftModelRunner supports TP>1
             #       and other restrictions that are part of DraftModelRunner's
             #       supports_gpu_multi_step(..)
+            if expanded_request.previous_hidden_states is not None:
+                self.worker.model_runner.return_hidden_states = True
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
+                self._maybe_update_previous_hidden_states(
+                    model_output, expanded_request)

                 self._append_new_tokens(
                     model_output, expanded_request.seq_group_metadata_list,
@@ -115,6 +119,19 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
                 model_outputs, indices_of_seq_with_bonus_tokens)
             return filtered_model_outputs, True

+    @staticmethod
+    def _maybe_update_previous_hidden_states(
+            model_output: SamplerOutput,
+            expanded_request: ExecuteModelRequest) -> None:
+        """
+        Updates the previous hidden states in an expanded request
+        in-place with the hidden states from the model output.
+        """
+        if expanded_request.previous_hidden_states is not None:
+            expanded_request.previous_hidden_states = HiddenStates(
+                model_output.hidden_states,
+                expanded_request.seq_group_metadata_list)
+
     @staticmethod
     def _expand_execute_model_request(
         execute_model_req: ExecuteModelRequest,
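
Putting the two insertions in sampler_output together with the new helper, the control flow can be sketched as follows. This is a simplified stand-alone version: run_draft_step is a hypothetical callable standing in for self.worker.execute_model, and _HiddenStates is a toy stand-in for vLLM's HiddenStates.

from dataclasses import dataclass
from typing import Any, Callable, List

@dataclass
class _HiddenStates:                      # toy stand-in for vLLM's HiddenStates
    hidden_states: Any
    seq_group_metadata_list: List[Any]

def propose_k_steps(run_draft_step: Callable, expanded_request, k: int) -> List[Any]:
    # Each draft step returns hidden states that seed previous_hidden_states
    # for the next step, which is what lets one MTP module be applied k times.
    outputs = []
    for _ in range(k):
        output = run_draft_step(expanded_request)
        outputs.append(output)
        if expanded_request.previous_hidden_states is not None:
            expanded_request.previous_hidden_states = _HiddenStates(
                output.hidden_states,
                expanded_request.seq_group_metadata_list)
        # (token appending and bonus-token filtering omitted)
    return outputs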


@@ -184,8 +184,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
         elif draft_model_config.hf_config.model_type == "medusa":
             proposer_worker = MedusaWorker(**draft_worker_kwargs)
         else:
-            if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
-                "deepseek_mtp":
+            if draft_tp == 1:
                 if current_platform.is_cuda_alike():
                     draft_worker_kwargs[
                         "model_runner_cls"] = TP1DraftModelRunner
@@ -203,7 +202,8 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)

             if draft_model_config.hf_config.model_type == "deepseek_mtp":
-                num_spec_prefill_steps = num_speculative_tokens
+                num_spec_prefill_steps = \
+                    draft_model_config.hf_config.n_predict

         proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
             proposer_worker, draft_tp, target_tp)
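
A hedged reading of this change: during prefill each physical MTP module only needs one pass to warm its state, so the number of spec-prefill steps should track n_predict (the modules in the checkpoint) rather than the user-requested k. A toy illustration of the difference, with example values:

n_predict = 1               # MTP modules shipped with the draft checkpoint
num_speculative_tokens = 4  # user-requested k, now allowed to exceed n_predict

num_spec_prefill_steps_old = num_speculative_tokens  # previously tied to k
num_spec_prefill_steps_new = n_predict               # one pass per MTP module
print(num_spec_prefill_steps_old, num_spec_prefill_steps_new)  # 4 1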


@@ -1685,11 +1685,22 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         # TODO(andoorve): We can remove this once all
         # virtual engines share the same kv cache.
         virtual_engine = model_input.virtual_engine
+        previous_hidden_states = kwargs.get("previous_hidden_states")
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
             model_executable = self.graph_runners[virtual_engine][
                 graph_batch_size]
+            if previous_hidden_states is not None:
+                previous_hidden_states = torch.cat([
+                    previous_hidden_states,
+                    torch.empty([
+                        graph_batch_size - previous_hidden_states.shape[0],
+                        *previous_hidden_states.shape[1:]
+                    ],
+                                dtype=previous_hidden_states.dtype,
+                                device=previous_hidden_states.device)
+                ])
         else:
             model_executable = self.model
@@ -1716,7 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                 "finished_requests_ids": model_input.finished_requests_ids,
                 "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
             } if self.has_inner_state else {}
-        previous_hidden_states = kwargs.get("previous_hidden_states")
         model_kwargs = {}
         if previous_hidden_states is not None:
             model_kwargs["previous_hidden_states"] = previous_hidden_states