[Model][Speculative Decoding] DeepSeek MTP spec decode (#12755)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Lucia Fang 2025-02-19 01:06:23 -08:00 committed by GitHub
parent 983a40a8bb
commit f525c0be8b
14 changed files with 727 additions and 46 deletions
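
For context, a minimal usage sketch of the feature this commit adds (not part of the diff): with the config changes below, a deepseek_v3 target can reuse its own MTP module as the draft, so only num_speculative_tokens needs to be set. The checkpoint name and settings are placeholders, not taken from this PR.

# Hypothetical invocation; adjust model name and resources to your setup.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # placeholder deepseek_v3 checkpoint
    trust_remote_code=True,
    num_speculative_tokens=1,  # draft defaults to the target's own MTP layer(s)
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)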

View File

@@ -2,7 +2,7 @@
# adding a new command to an existing step. See different options here for examples.
# This script will be feed into Jinja template in `test-template-aws.j2` at
# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
# to generate the final pipeline yaml file.
# Documentation
@@ -15,7 +15,7 @@
# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd]
# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100
# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4.
# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host,
# in this case, commands must be specified. the first command runs on first host, the second
# command runs on the second host.
# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests
@@ -24,8 +24,8 @@
# When adding a test
# - If the test belong to an existing group, add it there
# - If the test is short, add to any existing step
# - If the test takes more than 10min, then it is okay to create a new step.
# Note that all steps execute in parallel.

steps:
##### fast check tests #####
@@ -145,14 +145,14 @@ steps:
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
- label: Metrics, Tracing Test # 10min
num_gpus: 2
fast_check: true
source_file_dependencies:
- vllm/
- tests/metrics
- tests/tracing
commands:
- pytest -v -s metrics
- "pip install \
'opentelemetry-sdk>=1.26.0,<1.27.0' \
'opentelemetry-api>=1.26.0,<1.27.0' \
@@ -254,7 +254,7 @@ steps:
- vllm/model_executor/guided_decoding
- tests/test_logits_processor
- tests/model_executor/test_guided_processors
commands:
- pytest -v -s test_logits_processor.py
- pytest -v -s model_executor/test_guided_processors.py
@@ -265,7 +265,7 @@ steps:
- vllm/model_executor/models/eagle.py
commands:
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
+ - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py

- label: LoRA Test %N # 15min each
@@ -580,7 +580,7 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# This test runs llama 13B, so it is required to run on 4 GPUs.
- pytest -v -s -x lora/test_long_context.py
# There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
@@ -605,7 +605,7 @@ steps:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt

##### multi gpus test #####
@@ -617,7 +617,7 @@ steps:
num_gpus: 4
source_file_dependencies:
- vllm/
commands:
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py

View File

@@ -296,6 +296,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501
"MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501
+ "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
+     speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
+     trust_remote_code=True),
}

_FALLBACK_MODEL = {

View File

@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify that the scenarios below pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least that MTP does not break the
correctness of the target model outputs.
"""
import pytest
from .conftest import run_equality_correctness_test
# main model
MAIN_MODEL = "luccafong/deepseek_mtp_main_random"
# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1
# precision
PRECISION = "bfloat16"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
},
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
output_len: int, seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@@ -763,7 +763,7 @@ class ModelConfig:
def is_deepseek_mla(self) -> bool:
return (hasattr(self.hf_text_config, "model_type")) \
and (self.hf_text_config.model_type in \
-     ('deepseek_v2', 'deepseek_v3'))\
+     ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
and (self.hf_text_config.kv_lora_rank is not None)

def get_head_size(self) -> int:
@@ -856,8 +856,12 @@ class ModelConfig:
def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
- total_num_hidden_layers = getattr(self.hf_text_config,
-                                   "num_hidden_layers", 0)
+ if self.hf_text_config.model_type == "deepseek_mtp":
+     total_num_hidden_layers = getattr(self.hf_text_config,
+                                       "num_nextn_predict_layers", 0)
+ else:
+     total_num_hidden_layers = getattr(self.hf_text_config,
+                                       "num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
@@ -1689,6 +1693,18 @@ class SpeculativeConfig:
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str

+ @staticmethod
+ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
+     if hf_config.model_type == "deepseek_v3":
+         hf_config.model_type = "deepseek_mtp"
+     if hf_config.model_type == "deepseek_mtp":
+         n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+         hf_config.update({
+             "n_predict": n_predict,
+             "architectures": ["DeepSeekMTPModel"]
+         })
+     return hf_config

@staticmethod
def maybe_create_spec_config(
target_model_config: ModelConfig,
@@ -1771,12 +1787,18 @@ class SpeculativeConfig:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
"""

if speculative_model is None:
if num_speculative_tokens is not None:
-     raise ValueError("num_speculative_tokens was provided without "
-                      "speculative_model.")
- return None
+     if target_model_config.hf_text_config.model_type \
+             == "deepseek_v3":
+         # use the draft model from the same model:
+         speculative_model = target_model_config.model
+     else:
+         raise ValueError(
+             "num_speculative_tokens was provided without "
+             "speculative_model.")
+ else:
+     return None

if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
@@ -1830,6 +1852,7 @@ class SpeculativeConfig:
max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
+ hf_overrides=SpeculativeConfig.hf_config_override,
)

draft_hf_config = draft_model_config.hf_config
@@ -1846,7 +1869,6 @@ class SpeculativeConfig:
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens

n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None:
if num_speculative_tokens is None:
@@ -1960,8 +1982,9 @@ class SpeculativeConfig:
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
-     "MLPSpeculator cannot currently be run with tp>1; "
-     "setting speculative_draft_tensor_parallel_size=1")
+     "%s cannot currently be run with tp>1; "
+     "setting speculative_draft_tensor_parallel_size=1",
+     draft_hf_config.model_type)
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
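
As a reference for the hf_config_override hook added earlier in this file, here is a minimal sketch of what it does when a deepseek_v3 config is reused for the draft; the config object built here is illustrative only, not code from this PR.

# Illustrative only: mimic the relevant fields of a deepseek_v3 checkpoint config.
from transformers import PretrainedConfig

from vllm.config import SpeculativeConfig

cfg = PretrainedConfig()
cfg.model_type = "deepseek_v3"
cfg.num_nextn_predict_layers = 1  # number of MTP modules shipped with the checkpoint

cfg = SpeculativeConfig.hf_config_override(cfg)

# The draft is now treated as a DeepSeek MTP model with one predict step.
assert cfg.model_type == "deepseek_mtp"
assert cfg.n_predict == 1
assert cfg.architectures == ["DeepSeekMTPModel"]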

View File

@@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name)
from .utils import maybe_prefix
class SharedHead(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(hidden_states)
class DeepSeekMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
cache_config, quant_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return self.shared_head(hidden_states)
class DeepSeekMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict({
str(idx):
DeepSeekMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
input_ids,
positions,
kv_caches[spec_step_idx],
attn_metadata,
previous_hidden_states,
inputs_embeds,
spec_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> torch.Tensor:
mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
logits = self.logits_processor(mtp_layer.shared_head.head,
hidden_states, sampling_metadata)
return logits
class DeepSeekMTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states, sampling_metadata,
spec_step_idx)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
"""
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
]
spec_layer_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
break
if not spec_layer_weight:
# treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
return name

View File

@@ -732,13 +732,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
if "rotary_emb.inv_freq" in name:
continue
- # TODO(simon): support nextn predict layers
- if hasattr(self.config, "num_nextn_predict_layers"
-            ) and self.config.num_nextn_predict_layers > 0:
-     assert self.config.num_nextn_predict_layers == 1
-     layer_idx = self.config.num_hidden_layers
-     if name.startswith(f"model.layers.{layer_idx}"):
-         continue
+ spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+ if spec_layer is not None:
+     continue  # skip spec decode layers for main model

for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@@ -805,3 +801,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass

+ def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
+                                         weight_name: str) -> Optional[int]:
+     if hasattr(config,
+                "num_nextn_predict_layers") and (config.num_nextn_predict_layers
+                                                 > 0):
+         layer_idx = config.num_hidden_layers
+         for i in range(config.num_nextn_predict_layers):
+             if weight_name.startswith(f"model.layers.{layer_idx+i}."):
+                 return layer_idx + i
+     return None

View File

@@ -187,6 +187,7 @@ _MULTIMODAL_MODELS = {
_SPECULATIVE_DECODING_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"),
+ "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

View File

@@ -1307,6 +1307,8 @@ class ExecuteModelRequest(
previous_hidden_states: Optional[HiddenStates] = None
# The number of forward steps to run.
num_steps: int = 1
+ # The step index for spec model input.
+ spec_step_idx: Optional[int] = None
# Finished request ids since last step.
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.

View File

@@ -153,7 +153,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
return False

# TODO: Add support for other attn backends
- if self.attn_backend.get_name() != "FLASH_ATTN":
+ if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
return False

# TODO: Add support for LORA
@@ -175,6 +175,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
+ **kwargs,
) -> Optional[List[SamplerOutput]]:
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
@@ -271,10 +272,17 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
for step in range(num_steps):
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
- kwargs = {"previous_hidden_states": hidden_states} \
+ model_execute_kwargs = {"previous_hidden_states": hidden_states} \
if previous_hidden_states is not None else {}
+ compute_logits_kwargs = {}
# Run model
+ if hasattr(self.model.config, "num_nextn_predict_layers"):
+     # for DeepSeek MTP only to use the corresponding layer for
+     # each step
+     spec_step_idx = kwargs.get("spec_step_idx", step)
+     model_execute_kwargs["spec_step_idx"] = spec_step_idx
+     compute_logits_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_states = model_executable(
@@ -285,13 +293,15 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
- **kwargs,
+ **model_execute_kwargs,
)

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
-                                    model_input.sampling_metadata)
+                                    model_input.sampling_metadata,
+                                    **compute_logits_kwargs)
+ if not self.is_driver_worker:
+     return []

# Sample the next token.
output = self.model.sample(
logits=logits,

View File

@@ -108,6 +108,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
+ num_speculative_tokens=speculative_config.num_speculative_tokens,
)
return spec_decode_worker
@@ -153,10 +154,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
+ num_speculative_tokens: int,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
enable_lm_head_weight_load = False
+ num_spec_prefill_steps = 1
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
@@ -179,14 +182,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
elif draft_model_config.hf_config.model_type == "medusa":
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
- if draft_tp == 1:
+ if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
+         "deepseek_mtp":
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
-     "EAGLE does not support TP > 1 yet")
+     f"{draft_model_config.hf_config.model_type} "
+     "does not support TP > 1 yet")
allow_zero_draft_token_step = False
@@ -195,6 +200,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
enable_lm_head_weight_load = True

proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+ if draft_model_config.hf_config.model_type == "deepseek_mtp":
+     num_spec_prefill_steps = num_speculative_tokens

proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp)
@@ -247,7 +254,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
- enable_lm_head_weight_load=enable_lm_head_weight_load)
+ enable_lm_head_weight_load=enable_lm_head_weight_load,
+ num_spec_prefill_steps=num_spec_prefill_steps)

def __init__(
self,
@@ -261,6 +269,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
enable_lm_head_weight_load: Optional[bool] = False,
+ num_spec_prefill_steps: int = 1,
):
"""
Create a SpecDecodeWorker.
@@ -293,6 +302,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft model is larger than 1 (TODO: #5814)
enable_lm_head_weight_load: whether to load lm_head weight for
draft models like eagle.
+ num_spec_prefill_steps: number of speculative prefill steps to run
+     before the speculative decoding starts. This is only used when
+     the draft model is a deepseek_mtp model that requires prefill
+     kv cache separately for each MTP layer.
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
@@ -326,6 +339,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states: Optional[HiddenStates] = None
self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats
+ self._num_spec_prefill_steps = num_spec_prefill_steps

def init_device(self) -> None:
"""Initialize both scorer and proposer models.
@@ -685,8 +699,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req.previous_hidden_states = \
prepare_prefill_hidden_states(
sampler_output.prefill_hidden_states)
- self.proposer_worker.execute_model(execute_model_req)
+ for i in range(self._num_spec_prefill_steps):
+     execute_model_req.spec_step_idx = i
+     self.proposer_worker.execute_model(execute_model_req)

sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)

View File

@@ -99,6 +99,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
virtual_engine: int = 0
async_callback: Optional[Callable] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
+ previous_hidden_states: Optional[torch.Tensor] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@@ -1649,6 +1650,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
+ **kwargs,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1706,6 +1708,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
+ previous_hidden_states = kwargs.get("previous_hidden_states")
+ model_kwargs = {}
+ if previous_hidden_states is not None:
+     model_kwargs["previous_hidden_states"] = previous_hidden_states

if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True)
@@ -1723,7 +1729,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
- **seqlen_agnostic_kwargs)
+ **seqlen_agnostic_kwargs,
+ **model_kwargs,
+ )

if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
@@ -1815,7 +1823,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
1. current vLLM instance is KV cache consumer/decode vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run

Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
@@ -1840,7 +1848,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
1. current vLLM instance is KV cache producer/prefill vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run

Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
@@ -1976,7 +1984,11 @@ class CUDAGraphRunner(nn.Module):
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
if positions is not None:
- self.input_buffers["positions"].copy_(positions, non_blocking=True)
+ # in some case like MLA, it will reuse positions in metadata
+ # but truncate them to the original size
+ # so the shape is not padded, we need to copy partial only
+ self.input_buffers["positions"][:positions.shape[0]].copy_(
+     positions, non_blocking=True)
if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_(

View File

@@ -46,7 +46,10 @@ def _init_attn_metadata_from_tensor_dict(
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
if field.name in tensor_dict:
- valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
+ if field.name == "input_positions":
+     valid_attn_kwargs[field.name] = tensor_dict[field.name]
+ else:
+     valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)

attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata

View File

@@ -68,10 +68,10 @@ class Worker(LocalOrDistributedWorkerBase):
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
- or (speculative_config.draft_model_config.model ==
-     model_config.model) \
+ or (speculative_config.draft_model_config.hf_config.model_type ==
+     model_config.hf_config.model_type) \
or (speculative_config.draft_model_config.hf_config.model_type
-     not in ["medusa", "mlp_speculator", "eagle"]) \
+     not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \
else {"return_hidden_states": True}

ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner

View File

@@ -397,6 +397,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_input, worker_input, kwargs = inputs
num_steps = worker_input.num_steps
+ if (execute_model_req is not None and execute_model_req.spec_step_idx):
+     kwargs["spec_step_idx"] = execute_model_req.spec_step_idx

self.execute_worker(worker_input)