mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:34:57 +08:00
[Model][Speculative Decoding] Expand DeepSeek MTP code to support k > n_predict (#13626)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
This commit is contained in:
parent
2e94b9cfbb
commit
9804145cac
@ -1978,13 +1978,12 @@ class SpeculativeConfig:
|
||||
if num_speculative_tokens is None:
|
||||
# Default to max value defined in draft model config.
|
||||
num_speculative_tokens = n_predict
|
||||
elif num_speculative_tokens > n_predict:
|
||||
# Verify provided value doesn't exceed the maximum
|
||||
# supported by the draft model.
|
||||
elif num_speculative_tokens > n_predict and \
|
||||
num_speculative_tokens % n_predict != 0:
|
||||
# Ensure divisibility for MTP module reuse.
|
||||
raise ValueError(
|
||||
"This speculative model supports a maximum of "
|
||||
f"num_speculative_tokens={n_predict}, but "
|
||||
f"{num_speculative_tokens=} was provided.")
|
||||
f"{num_speculative_tokens=} must be divisible by "
|
||||
f"{n_predict=}")
|
||||
|
||||
speculative_draft_tensor_parallel_size = \
|
||||
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
|
||||
|
||||
@ -87,7 +87,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return self.shared_head(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
@ -121,12 +121,13 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
|
||||
input_ids,
|
||||
positions,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
spec_step_idx,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
@ -135,9 +136,12 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers[str(self.mtp_start_layer_idx +
|
||||
current_step_idx)]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
hidden_states, sampling_metadata)
|
||||
mtp_layer.shared_head(hidden_states),
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
|
||||
@ -50,12 +50,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
"""
|
||||
|
||||
def __init__(self, model_runner: ModelRunnerBase):
|
||||
if hasattr(
|
||||
model_runner,
|
||||
"return_hidden_states") and model_runner.return_hidden_states:
|
||||
raise ValueError(
|
||||
"return_hidden_states is not supported for TP1DraftModelRunner."
|
||||
)
|
||||
super().__init__(model_runner)
|
||||
|
||||
self.indices_of_seq_with_bonus_tokens = None
|
||||
@ -153,7 +147,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
return False
|
||||
|
||||
# TODO: Add support for other attn backends
|
||||
if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
|
||||
if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
|
||||
return False
|
||||
|
||||
# TODO: Add support for LORA
|
||||
@ -307,6 +301,9 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
)
|
||||
outputs.append(output)
|
||||
|
||||
if self.return_hidden_states and is_fallback:
|
||||
output.hidden_states = hidden_states
|
||||
|
||||
if model_input.attn_metadata.num_prefills == 0 \
|
||||
and self.indices_of_seq_with_bonus_tokens is not None:
|
||||
assert output.sampled_token_ids is not None
|
||||
|
||||
@ -96,12 +96,16 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
|
||||
# TODO: Remove this branch once DraftModelRunner supports TP>1
|
||||
# and other restrictions that are part of DraftModelRunner's
|
||||
# supports_gpu_multi_step(..)
|
||||
if expanded_request.previous_hidden_states is not None:
|
||||
self.worker.model_runner.return_hidden_states = True
|
||||
for _ in range(sample_len):
|
||||
model_output: List[SamplerOutput] = self.worker.execute_model(
|
||||
execute_model_req=expanded_request)
|
||||
assert (len(model_output) == 1
|
||||
), "composing multistep workers not supported"
|
||||
model_output = model_output[0]
|
||||
self._maybe_update_previous_hidden_states(
|
||||
model_output, expanded_request)
|
||||
|
||||
self._append_new_tokens(
|
||||
model_output, expanded_request.seq_group_metadata_list,
|
||||
@ -115,6 +119,19 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
|
||||
model_outputs, indices_of_seq_with_bonus_tokens)
|
||||
return filtered_model_outputs, True
|
||||
|
||||
@staticmethod
|
||||
def _maybe_update_previous_hidden_states(
|
||||
model_output: SamplerOutput,
|
||||
expanded_request: ExecuteModelRequest) -> None:
|
||||
"""
|
||||
Updates the previous hidden states in an expanded request
|
||||
in-place with the hidden states from the model output.
|
||||
"""
|
||||
if expanded_request.previous_hidden_states is not None:
|
||||
expanded_request.previous_hidden_states = HiddenStates(
|
||||
model_output.hidden_states,
|
||||
expanded_request.seq_group_metadata_list)
|
||||
|
||||
@staticmethod
|
||||
def _expand_execute_model_request(
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
|
||||
@ -184,8 +184,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
|
||||
elif draft_model_config.hf_config.model_type == "medusa":
|
||||
proposer_worker = MedusaWorker(**draft_worker_kwargs)
|
||||
else:
|
||||
if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
|
||||
"deepseek_mtp":
|
||||
if draft_tp == 1:
|
||||
if current_platform.is_cuda_alike():
|
||||
draft_worker_kwargs[
|
||||
"model_runner_cls"] = TP1DraftModelRunner
|
||||
@ -203,7 +202,8 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
|
||||
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
if draft_model_config.hf_config.model_type == "deepseek_mtp":
|
||||
num_spec_prefill_steps = num_speculative_tokens
|
||||
num_spec_prefill_steps = \
|
||||
draft_model_config.hf_config.n_predict
|
||||
|
||||
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
|
||||
proposer_worker, draft_tp, target_tp)
|
||||
|
||||
@ -1685,11 +1685,22 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
# TODO(andoorve): We can remove this once all
|
||||
# virtual engines share the same kv cache.
|
||||
virtual_engine = model_input.virtual_engine
|
||||
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[virtual_engine][
|
||||
graph_batch_size]
|
||||
if previous_hidden_states is not None:
|
||||
previous_hidden_states = torch.cat([
|
||||
previous_hidden_states,
|
||||
torch.empty([
|
||||
graph_batch_size - previous_hidden_states.shape[0],
|
||||
*previous_hidden_states.shape[1:]
|
||||
],
|
||||
dtype=previous_hidden_states.dtype,
|
||||
device=previous_hidden_states.device)
|
||||
])
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
@ -1716,7 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
|
||||
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||
model_kwargs = {}
|
||||
if previous_hidden_states is not None:
|
||||
model_kwargs["previous_hidden_states"] = previous_hidden_states
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user