commit 9804145cac (parent 2e94b9cfbb)

[Model][Speculative Decoding] Expand DeepSeek MTP code to support k > n_predict (#13626)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
```diff
@@ -1978,13 +1978,12 @@ class SpeculativeConfig:
                 if num_speculative_tokens is None:
                     # Default to max value defined in draft model config.
                     num_speculative_tokens = n_predict
-                elif num_speculative_tokens > n_predict:
-                    # Verify provided value doesn't exceed the maximum
-                    # supported by the draft model.
+                elif num_speculative_tokens > n_predict and \
+                        num_speculative_tokens % n_predict != 0:
+                    # Ensure divisibility for MTP module reuse.
                     raise ValueError(
-                        "This speculative model supports a maximum of "
-                        f"num_speculative_tokens={n_predict}, but "
-                        f"{num_speculative_tokens=} was provided.")
+                        f"{num_speculative_tokens=} must be divisible by "
+                        f"{n_predict=}")
 
             speculative_draft_tensor_parallel_size = \
                 SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
```
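For readers skimming the hunk: the config check now allows requesting more speculative tokens than the draft model's `n_predict`, provided the request is a whole multiple so the MTP modules can be cycled. A minimal standalone sketch of that rule (the helper name and the assert values are illustrative, not part of vLLM):

```python
from typing import Optional


def resolve_num_speculative_tokens(num_speculative_tokens: Optional[int],
                                   n_predict: int) -> int:
    """Illustrative helper mirroring the SpeculativeConfig check above."""
    if num_speculative_tokens is None:
        # Default to the maximum defined by the draft model config.
        return n_predict
    if (num_speculative_tokens > n_predict
            and num_speculative_tokens % n_predict != 0):
        # k may exceed n_predict only if the MTP modules can be reused a
        # whole number of times (k = m * n_predict).
        raise ValueError(f"{num_speculative_tokens=} must be divisible by "
                         f"{n_predict=}")
    return num_speculative_tokens


assert resolve_num_speculative_tokens(None, 1) == 1  # default to n_predict
assert resolve_num_speculative_tokens(4, 2) == 4     # 4 % 2 == 0, allowed
# resolve_num_speculative_tokens(3, 2) would raise ValueError (3 % 2 != 0)
```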
```diff
@@ -87,7 +87,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
                                                  hidden_states=hidden_states,
                                                  residual=None)
         hidden_states = residual + hidden_states
-        return self.shared_head(hidden_states)
+        return hidden_states
 
 
 class DeepSeekMultiTokenPredictor(nn.Module):
```
```diff
@@ -121,12 +121,13 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
             positions,
             previous_hidden_states,
             inputs_embeds,
-            spec_step_idx,
+            current_step_idx,
         )
 
     def compute_logits(
```
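A quick illustration (not part of the diff) of what the modulo indexing buys: with `num_mtp_layers` MTP modules, speculative step `k` is routed to module `k % num_mtp_layers`, so any multiple of `n_predict` steps can be served. The layer index value below is made up for the example.

```python
# Suppose the draft model ships 2 MTP modules stored under string keys
# starting at layer index 61 (an illustrative value).
num_mtp_layers = 2
mtp_start_layer_idx = 61

for spec_step_idx in range(6):  # k = 6 speculative steps, 6 % 2 == 0
    current_step_idx = spec_step_idx % num_mtp_layers
    layer_key = str(mtp_start_layer_idx + current_step_idx)
    print(f"step {spec_step_idx} -> layers[{layer_key!r}]")
# step 0 -> layers['61'], step 1 -> layers['62'], step 2 -> layers['61'], ...
```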
```diff
@@ -135,9 +136,12 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        mtp_layer = self.layers[str(self.mtp_start_layer_idx +
+                                    current_step_idx)]
         logits = self.logits_processor(mtp_layer.shared_head.head,
-                                       hidden_states, sampling_metadata)
+                                       mtp_layer.shared_head(hidden_states),
+                                       sampling_metadata)
         return logits
 
 
```
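Because the predictor layer now returns un-projected hidden states (first hunk in this file), `compute_logits` applies the shared head itself before the LM head. A self-contained toy of that order of operations, with made-up shapes and a plain RMSNorm standing in for the shared head's norm:

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             eps: float = 1e-6) -> torch.Tensor:
    # Stand-in for the shared head's final normalization.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


hidden_states = torch.randn(4, 16)       # [num_tokens, hidden_size], raw output
norm_weight = torch.ones(16)
lm_head_weight = torch.randn(32000, 16)  # [vocab_size, hidden_size]

# New order: normalize first (shared head), then project to the vocabulary.
logits = rms_norm(hidden_states, norm_weight) @ lm_head_weight.t()
print(logits.shape)  # torch.Size([4, 32000])
```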
```diff
@@ -50,12 +50,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """
 
     def __init__(self, model_runner: ModelRunnerBase):
-        if hasattr(
-                model_runner,
-                "return_hidden_states") and model_runner.return_hidden_states:
-            raise ValueError(
-                "return_hidden_states is not supported for TP1DraftModelRunner."
-            )
         super().__init__(model_runner)
 
         self.indices_of_seq_with_bonus_tokens = None
```
```diff
@@ -153,7 +147,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             return False
 
         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
+        if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
            return False
 
         # TODO: Add support for LORA
```
```diff
@@ -307,6 +301,9 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             )
             outputs.append(output)
 
+            if self.return_hidden_states and is_fallback:
+                output.hidden_states = hidden_states
+
             if model_input.attn_metadata.num_prefills == 0 \
                 and self.indices_of_seq_with_bonus_tokens is not None:
                 assert output.sampled_token_ids is not None
```
```diff
@@ -96,12 +96,16 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
             # TODO: Remove this branch once DraftModelRunner supports TP>1
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
+            if expanded_request.previous_hidden_states is not None:
+                self.worker.model_runner.return_hidden_states = True
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
+                self._maybe_update_previous_hidden_states(
+                    model_output, expanded_request)
 
                 self._append_new_tokens(
                     model_output, expanded_request.seq_group_metadata_list,
```
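Taken together with the TP1DraftModelRunner hunk above, the effect of these additions is that each fallback draft step feeds its hidden states back into the request used by the next step. A simplified, self-contained sketch of that chaining; every name here is a stand-in rather than a vLLM API:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyRequest:
    previous_hidden_states: Optional[torch.Tensor]


@dataclass
class ToyOutput:
    hidden_states: torch.Tensor


def run_draft_step(req: ToyRequest) -> ToyOutput:
    # Stand-in for worker.execute_model(): pretend the MTP module maps the
    # previous hidden states to the next step's hidden states.
    return ToyOutput(hidden_states=req.previous_hidden_states + 1.0)


def propose(req: ToyRequest, sample_len: int) -> None:
    for _ in range(sample_len):
        out = run_draft_step(req)
        # Mirror of _maybe_update_previous_hidden_states: chain the step's
        # hidden states into the request consumed by the next step.
        if req.previous_hidden_states is not None:
            req.previous_hidden_states = out.hidden_states


req = ToyRequest(previous_hidden_states=torch.zeros(2, 4))
propose(req, sample_len=3)
print(req.previous_hidden_states[0, 0].item())  # 3.0 after three chained steps
```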
```diff
@@ -115,6 +119,19 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
                 model_outputs, indices_of_seq_with_bonus_tokens)
         return filtered_model_outputs, True
 
+    @staticmethod
+    def _maybe_update_previous_hidden_states(
+            model_output: SamplerOutput,
+            expanded_request: ExecuteModelRequest) -> None:
+        """
+        Updates the previous hidden states in an expanded request
+        in-place with the hidden states from the model output.
+        """
+        if expanded_request.previous_hidden_states is not None:
+            expanded_request.previous_hidden_states = HiddenStates(
+                model_output.hidden_states,
+                expanded_request.seq_group_metadata_list)
+
     @staticmethod
     def _expand_execute_model_request(
         execute_model_req: ExecuteModelRequest,
```
```diff
@@ -184,8 +184,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
         elif draft_model_config.hf_config.model_type == "medusa":
             proposer_worker = MedusaWorker(**draft_worker_kwargs)
         else:
-            if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
-                    "deepseek_mtp":
+            if draft_tp == 1:
                 if current_platform.is_cuda_alike():
                     draft_worker_kwargs[
                         "model_runner_cls"] = TP1DraftModelRunner
```
```diff
@@ -203,7 +202,8 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
 
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
             if draft_model_config.hf_config.model_type == "deepseek_mtp":
-                num_spec_prefill_steps = num_speculative_tokens
+                num_spec_prefill_steps = \
+                    draft_model_config.hf_config.n_predict
 
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
```
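The distinction this hunk introduces, with illustrative numbers: prefill only needs as many MTP passes as the draft model actually has modules (`n_predict`), while decode may still propose the larger, user-requested `num_speculative_tokens` by cycling those modules. Variable names below are stand-ins for the example.

```python
n_predict = 1               # MTP modules shipped with the draft model
num_speculative_tokens = 3  # k requested by the user; 3 % 1 == 0, so allowed

num_spec_prefill_steps = n_predict         # prefill: one pass per MTP module
num_decode_steps = num_speculative_tokens  # decode: modules reused via modulo

print(num_spec_prefill_steps, num_decode_steps)  # 1 3
```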
```diff
@@ -1685,11 +1685,22 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         # TODO(andoorve): We can remove this once all
         # virtual engines share the same kv cache.
         virtual_engine = model_input.virtual_engine
+        previous_hidden_states = kwargs.get("previous_hidden_states")
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
             model_executable = self.graph_runners[virtual_engine][
                 graph_batch_size]
+            if previous_hidden_states is not None:
+                previous_hidden_states = torch.cat([
+                    previous_hidden_states,
+                    torch.empty([
+                        graph_batch_size - previous_hidden_states.shape[0],
+                        *previous_hidden_states.shape[1:]
+                    ],
+                                dtype=previous_hidden_states.dtype,
+                                device=previous_hidden_states.device)
+                ])
         else:
             model_executable = self.model
 
```
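A self-contained sketch of the padding logic added above: CUDA graphs are captured for fixed batch sizes, so when fewer sequences are live than the captured `graph_batch_size`, `previous_hidden_states` is extended with an uninitialized tail to match. Tensor sizes are made up for the example.

```python
import torch

graph_batch_size = 8                         # batch size the graph was captured with
previous_hidden_states = torch.randn(5, 16)  # only 5 live sequences this step

padded = torch.cat([
    previous_hidden_states,
    torch.empty([graph_batch_size - previous_hidden_states.shape[0],
                 *previous_hidden_states.shape[1:]],
                dtype=previous_hidden_states.dtype,
                device=previous_hidden_states.device),
])
print(padded.shape)  # torch.Size([8, 16])
```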
```diff
@@ -1716,7 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_inner_state else {}
-        previous_hidden_states = kwargs.get("previous_hidden_states")
         model_kwargs = {}
         if previous_hidden_states is not None:
             model_kwargs["previous_hidden_states"] = previous_hidden_states
```